diff --git a/amaranth-stubs b/amaranth-stubs index 480a38672..e8efebb9d 160000 --- a/amaranth-stubs +++ b/amaranth-stubs @@ -1 +1 @@ -Subproject commit 480a38672c6e84f4827a458e50e06d69f5588713 +Subproject commit e8efebb9dfc8f89a93b92b290eb9f9b11899ed0d diff --git a/coreblocks/core_structs/rob.py b/coreblocks/core_structs/rob.py index 72a3b291d..c0cc4ac13 100644 --- a/coreblocks/core_structs/rob.py +++ b/coreblocks/core_structs/rob.py @@ -1,5 +1,4 @@ from amaranth import * -from amaranth.lib.data import View import amaranth.lib.memory as memory from transactron import Method, Transaction, def_method, TModule from transactron.lib.metrics import * @@ -19,7 +18,7 @@ def __init__(self, gen_params: GenParams) -> None: self.retire = Method() self.done = Array(Signal() for _ in range(2**self.params.rob_entries_bits)) self.exception = Array(Signal() for _ in range(2**self.params.rob_entries_bits)) - self.data = memory.Memory(shape=layouts.data_layout.size, depth=2**self.params.rob_entries_bits, init=[]) + self.data = memory.Memory(shape=layouts.data_layout, depth=2**self.params.rob_entries_bits, init=[]) self.get_indices = Method(o=layouts.get_indices, nonexclusive=True) self.perf_rob_wait_time = FIFOLatencyMeasurer( @@ -54,8 +53,8 @@ def elaborate(self, platform): @def_method(m, self.peek, ready=peek_possible) def _(): - return { # remove View after Amaranth upgrade - "rob_data": View(self.params.get(ROBLayouts).data_layout, read_port.data), + return { + "rob_data": read_port.data, "rob_id": start_idx, "exception": self.exception[start_idx], } diff --git a/pytest.ini b/pytest.ini index 142b00abe..c2bc22f2b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,5 +4,6 @@ testpaths = tests norecursedirs = '*.egg', '.*', 'build', 'dist', 'venv', '__traces__', '__pycache__' filterwarnings = + ignore:cannot collect test class 'TestbenchContext':pytest.PytestCollectionWarning ignore:cannot collect test class 'TestbenchIO':pytest.PytestCollectionWarning ignore:No files were found in testpaths:pytest.PytestConfigWarning: diff --git a/requirements.txt b/requirements.txt index d3260abbb..376b0f0b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ ./amaranth-stubs/ # can't use -e -- pyright doesn't see the stubs then :( amaranth-yosys==0.40.0.0.post100 -git+https://github.com/amaranth-lang/amaranth@5e59189c2b8689a453891e17e378bf73806efdd3 +git+https://github.com/amaranth-lang/amaranth@994fa815995b1ac5b3c708915dcece2a45796569 dataclasses-json==0.6.3 diff --git a/test/backend/test_annoucement.py b/test/backend/test_annoucement.py index e6fd56fa3..71b83a319 100644 --- a/test/backend/test_annoucement.py +++ b/test/backend/test_annoucement.py @@ -8,7 +8,7 @@ from coreblocks.interface.layouts import * from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config -from transactron.testing import TestCaseWithSimulator, TestbenchIO +from transactron.testing import TestCaseWithSimulator, TestbenchIO, TestbenchContext class BackendTestCircuit(Elaboratable): @@ -104,32 +104,33 @@ def generate_producer(self, i: int): results to its output FIFO. This records will be next serialized by FUArbiter. """ - def producer(): + async def producer(sim: TestbenchContext): inputs = self.fu_inputs[i] for rob_id, result, rp_dst in inputs: io: TestbenchIO = self.m.fu_fifo_ins[i] - yield from io.call_init(rob_id=rob_id, result=result, rp_dst=rp_dst) - yield from self.random_wait(self.max_wait) + io.call_init(sim, rob_id=rob_id, result=result, rp_dst=rp_dst) + await self.random_wait(sim, self.max_wait) self.producer_end[i] = True return producer - def consumer(self): - yield from self.m.rs_announce_val_tbio.enable() - yield from self.m.rob_mark_done_tbio.enable() + async def consumer(self, sim: TestbenchContext): + # TODO: this test doesn't do anything, fix it! + self.m.rs_announce_val_tbio.enable(sim) + self.m.rob_mark_done_tbio.enable(sim) while reduce(and_, self.producer_end, True): # All 3 methods (in RF, RS and ROB) need to be enabled for the result # announcement transaction to take place. We want to have at least one # method disabled most of the time, so that the transaction is performed # only when we enable it inside the loop. Otherwise the transaction could # get executed at any time, particularly when we wouldn't be monitoring it - yield from self.m.rf_announce_val_tbio.enable() + self.m.rf_announce_val_tbio.enable(sim) - rf_result = yield from self.m.rf_announce_val_tbio.method_argument() - rs_result = yield from self.m.rs_announce_val_tbio.method_argument() - rob_result = yield from self.m.rob_mark_done_tbio.method_argument() + rf_result = self.m.rf_announce_val_tbio.get_outputs(sim) + rs_result = self.m.rs_announce_val_tbio.get_outputs(sim) + rob_result = self.m.rob_mark_done_tbio.get_outputs(sim) - yield from self.m.rf_announce_val_tbio.disable() + self.m.rf_announce_val_tbio.disable(sim) assert rf_result is not None assert rs_result is not None @@ -144,20 +145,20 @@ def consumer(self): del self.expected_output[t] else: self.expected_output[t] -= 1 - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) def test_one_out(self): self.fu_count = 1 self.initialize() with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) + sim.add_testbench(self.consumer) for i in range(self.fu_count): - sim.add_process(self.generate_producer(i)) + sim.add_testbench(self.generate_producer(i)) def test_many_out(self): self.fu_count = 4 self.initialize() with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) + sim.add_testbench(self.consumer) for i in range(self.fu_count): - sim.add_process(self.generate_producer(i)) + sim.add_testbench(self.generate_producer(i)) diff --git a/test/backend/test_retirement.py b/test/backend/test_retirement.py index a155caf04..cf039ed13 100644 --- a/test/backend/test_retirement.py +++ b/test/backend/test_retirement.py @@ -12,6 +12,7 @@ from coreblocks.params import GenParams from coreblocks.interface.layouts import ROBLayouts, RFLayouts, SchedulerLayouts from coreblocks.params.configurations import test_core_config +from transactron.lib.adapters import AdapterTrans from transactron.testing import * from collections import deque @@ -120,42 +121,46 @@ def setup_method(self): # (and the retirement code doesn't have any special behaviour to handle these cases), but in this simple # test we don't care to make sure that the randomly generated inputs are correct in this way. - @def_method_mock(lambda self: self.retc.mock_rob_retire, enable=lambda self: bool(self.submit_q), sched_prio=1) + @def_method_mock(lambda self: self.retc.mock_rob_retire, enable=lambda self: bool(self.submit_q)) def retire_process(self): - self.submit_q.popleft() + @MethodMock.effect + def eff(): + self.submit_q.popleft() @def_method_mock(lambda self: self.retc.mock_rob_peek, enable=lambda self: bool(self.submit_q)) def peek_process(self): return self.submit_q[0] - def free_reg_process(self): + async def free_reg_process(self, sim: TestbenchContext): while self.rf_exp_q: - reg = yield from self.retc.free_rf_adapter.call() + reg = await self.retc.free_rf_adapter.call(sim) assert reg["reg_id"] == self.rf_exp_q.popleft() - def rat_process(self): + async def rat_process(self, sim: TestbenchContext): while self.rat_map_q: current_map = self.rat_map_q.popleft() wait_cycles = 0 # this test waits for next rat pair to be correctly set and will timeout if that assignment fails - while (yield self.retc.rat.entries[current_map["rl_dst"]]) != current_map["rp_dst"]: + while sim.get(self.retc.rat.entries[current_map["rl_dst"]]) != current_map["rp_dst"]: wait_cycles += 1 if wait_cycles >= self.cycles + 10: assert False, "RAT entry was not updated" - yield Tick() + await sim.tick() assert not self.submit_q assert not self.rf_free_q - def precommit_process(self): + async def precommit_process(self, sim: TestbenchContext): while self.precommit_q: - info = yield from self.retc.precommit_adapter.call_try(rob_id=self.precommit_q[0]) + info = await self.retc.precommit_adapter.call_try(sim, rob_id=self.precommit_q[0]) assert info is not None assert info["side_fx"] self.precommit_q.popleft() - @def_method_mock(lambda self: self.retc.mock_rf_free, sched_prio=2) + @def_method_mock(lambda self: self.retc.mock_rf_free) def rf_free_process(self, reg_id): - assert reg_id == self.rf_free_q.popleft() + @MethodMock.effect + def eff(): + assert reg_id == self.rf_free_q.popleft() @def_method_mock(lambda self: self.retc.mock_exception_cause) def exception_cause_process(self): @@ -174,7 +179,7 @@ def mock_trap_entry_process(self): pass @def_method_mock(lambda self: self.retc.mock_fetch_continue) - def mock_fetch_continue_process(self): + def mock_fetch_continue_process(self, pc): pass @def_method_mock(lambda self: self.retc.mock_async_interrupt_cause) @@ -184,6 +189,6 @@ def mock_async_interrupt_cause(self): def test_rand(self): self.retc = RetirementTestCircuit(self.gen_params) with self.run_simulation(self.retc) as sim: - sim.add_process(self.free_reg_process) - sim.add_process(self.rat_process) - sim.add_process(self.precommit_process) + sim.add_testbench(self.free_reg_process) + sim.add_testbench(self.rat_process) + sim.add_testbench(self.precommit_process) diff --git a/test/cache/test_icache.py b/test/cache/test_icache.py index a52d75f35..88d44450a 100644 --- a/test/cache/test_icache.py +++ b/test/cache/test_icache.py @@ -3,7 +3,6 @@ import random from amaranth import Elaboratable, Module -from amaranth.sim import Passive, Settle, Tick from amaranth.utils import exact_log2 from transactron.lib import AdapterTrans, Adapter @@ -15,7 +14,10 @@ from coreblocks.params.configurations import test_core_config from coreblocks.cache.refiller import SimpleCommonBusCacheRefiller -from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, RecordIntDictRet +from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, TestbenchContext +from transactron.testing.functions import MethodData +from transactron.testing.method_mock import MethodMock +from transactron.testing.testbenchio import CallTrigger from ..peripherals.test_wishbone import WishboneInterfaceWrapper @@ -98,35 +100,29 @@ def setup_method(self) -> None: self.bad_addresses.add(bad_addr) self.bad_fetch_blocks.add(bad_addr & ~(self.cp.fetch_block_bytes - 1)) - def wishbone_slave(self): - yield Passive() - + async def wishbone_slave(self, sim: TestbenchContext): while True: - yield from self.test_module.wb_ctrl.slave_wait() + adr, *_ = await self.test_module.wb_ctrl.slave_wait(sim) # Wishbone is addressing words, so we need to shift it a bit to get the real address. - addr = (yield self.test_module.wb_ctrl.wb.adr) << exact_log2(self.cp.word_width_bytes) + addr = adr << exact_log2(self.cp.word_width_bytes) - yield Tick() - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) err = 1 if addr in self.bad_addresses else 0 data = random.randrange(2**self.gen_params.isa.xlen) self.mem[addr] = data - yield from self.test_module.wb_ctrl.slave_respond(data, err=err) - - yield Settle() + await self.test_module.wb_ctrl.slave_respond(sim, data, err=err) - def refiller_process(self): + async def refiller_process(self, sim: TestbenchContext): while self.requests: req_addr = self.requests.pop() - yield from self.test_module.start_refill.call(addr=req_addr) + await self.test_module.start_refill.call(sim, addr=req_addr) for i in range(self.cp.fetch_blocks_in_line): - ret = yield from self.test_module.accept_refill.call() + ret = await self.test_module.accept_refill.call(sim) cur_addr = req_addr + i * self.cp.fetch_block_bytes @@ -149,8 +145,8 @@ def refiller_process(self): def test(self): with self.run_simulation(self.test_module) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.refiller_process) + sim.add_testbench(self.wishbone_slave, background=True) + sim.add_testbench(self.refiller_process) class ICacheBypassTestCircuit(Elaboratable): @@ -220,17 +216,14 @@ def load_or_gen_mem(self, addr: int): self.mem[addr] = random.randrange(2**self.gen_params.isa.ilen) return self.mem[addr] - def wishbone_slave(self): - yield Passive() - + async def wishbone_slave(self, sim: TestbenchContext): while True: - yield from self.m.wb_ctrl.slave_wait() + adr, *_ = await self.m.wb_ctrl.slave_wait(sim) # Wishbone is addressing words, so we need to shift it a bit to get the real address. - addr = (yield self.m.wb_ctrl.wb.adr) << exact_log2(self.cp.word_width_bytes) + addr = adr << exact_log2(self.cp.word_width_bytes) - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) err = 1 if addr in self.bad_addrs else 0 @@ -238,19 +231,16 @@ def wishbone_slave(self): if self.gen_params.isa.xlen == 64: data = self.load_or_gen_mem(addr + 4) << 32 | data - yield from self.m.wb_ctrl.slave_respond(data, err=err) + await self.m.wb_ctrl.slave_respond(sim, data, err=err) - yield Settle() - - def user_process(self): + async def user_process(self, sim: TestbenchContext): while self.requests: req_addr = self.requests.popleft() & ~(self.cp.fetch_block_bytes - 1) - yield from self.m.issue_req.call(addr=req_addr) + await self.m.issue_req.call(sim, addr=req_addr) - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) - ret = yield from self.m.accept_res.call() + ret = await self.m.accept_res.call(sim) if (req_addr & ~(self.cp.word_width_bytes - 1)) in self.bad_addrs: assert ret["error"] @@ -262,13 +252,12 @@ def user_process(self): data |= self.mem[req_addr + 4] << 32 assert ret["fetch_block"] == data - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) def test(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.user_process) + sim.add_testbench(self.wishbone_slave, background=True) + sim.add_testbench(self.user_process) class MockedCacheRefiller(Elaboratable, CacheRefillerInterface): @@ -328,6 +317,7 @@ def setup_method(self) -> None: self.bad_addrs = set() self.bad_cache_lines = set() self.refill_requests = deque() + self.refill_block_cnt = 0 self.issued_requests = deque() self.accept_refill_request = True @@ -351,12 +341,17 @@ def init_module(self, ways, sets) -> None: @def_method_mock(lambda self: self.m.refiller.start_refill_mock, enable=lambda self: self.accept_refill_request) def start_refill_mock(self, addr): - self.refill_requests.append(addr) - self.refill_block_cnt = 0 - self.refill_in_fly = True - self.refill_addr = addr + @MethodMock.effect + def eff(): + self.refill_requests.append(addr) + self.refill_block_cnt = 0 + self.refill_in_fly = True + self.refill_addr = addr + + def enen(self): + return self.refill_in_fly - @def_method_mock(lambda self: self.m.refiller.accept_refill_mock, enable=lambda self: self.refill_in_fly) + @def_method_mock(lambda self: self.m.refiller.accept_refill_mock, enable=enen) def accept_refill_mock(self): addr = self.refill_addr + self.refill_block_cnt * self.cp.fetch_block_bytes @@ -367,12 +362,14 @@ def accept_refill_mock(self): if addr + i in self.bad_addrs: bad_addr = True - self.refill_block_cnt += 1 + last = self.refill_block_cnt + 1 == self.cp.fetch_blocks_in_line or bad_addr - last = self.refill_block_cnt == self.cp.fetch_blocks_in_line or bad_addr + @MethodMock.effect + def eff(): + self.refill_block_cnt += 1 - if last: - self.refill_in_fly = False + if last: + self.refill_in_fly = False return { "addr": addr, @@ -390,18 +387,19 @@ def add_bad_addr(self, addr: int): self.bad_addrs.add(addr) self.bad_cache_lines.add(addr & ~((1 << self.cp.offset_bits) - 1)) - def send_req(self, addr: int): + async def send_req(self, sim: TestbenchContext, addr: int): self.issued_requests.append(addr) - yield from self.m.issue_req.call(addr=addr) + await self.m.issue_req.call(sim, addr=addr) - def expect_resp(self, wait=False): - yield Settle() + async def expect_resp(self, sim: TestbenchContext, wait=False): if wait: - yield from self.m.accept_res.wait_until_done() + *_, resp = await self.m.accept_res.sample_outputs_until_done(sim) + else: + *_, resp = await self.m.accept_res.sample_outputs(sim) - self.assert_resp((yield from self.m.accept_res.get_outputs())) + self.assert_resp(resp) - def assert_resp(self, resp: RecordIntDictRet): + def assert_resp(self, resp: MethodData): addr = self.issued_requests.popleft() & ~(self.cp.fetch_block_bytes - 1) if (addr & ~((1 << self.cp.offset_bits) - 1)) in self.bad_cache_lines: @@ -417,343 +415,321 @@ def assert_resp(self, resp: RecordIntDictRet): def expect_refill(self, addr: int): assert self.refill_requests.popleft() == addr - def call_cache(self, addr: int): - yield from self.send_req(addr) - yield from self.m.accept_res.enable() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + async def call_cache(self, sim: TestbenchContext, addr: int): + await self.send_req(sim, addr) + self.m.accept_res.enable(sim) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) def test_1_way(self): self.init_module(1, 4) - def cache_user_process(): + async def cache_user_process(sim: TestbenchContext): # The first request should cause a cache miss - yield from self.call_cache(0x00010004) + await self.call_cache(sim, 0x00010004) self.expect_refill(0x00010000) # Accesses to the same cache line shouldn't cause a cache miss for i in range(self.cp.fetch_blocks_in_line): - yield from self.call_cache(0x00010000 + i * self.cp.fetch_block_bytes) + await self.call_cache(sim, 0x00010000 + i * self.cp.fetch_block_bytes) assert len(self.refill_requests) == 0 # Now go beyond the first cache line - yield from self.call_cache(0x00010000 + self.cp.line_size_bytes) + await self.call_cache(sim, 0x00010000 + self.cp.line_size_bytes) self.expect_refill(0x00010000 + self.cp.line_size_bytes) # Trigger cache aliasing - yield from self.call_cache(0x00020000) - yield from self.call_cache(0x00010000) + await self.call_cache(sim, 0x00020000) + await self.call_cache(sim, 0x00010000) self.expect_refill(0x00020000) self.expect_refill(0x00010000) # Fill the whole cache for i in range(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4): - yield from self.call_cache(i) + await self.call_cache(sim, i) for i in range(self.cp.num_of_sets): self.expect_refill(i * self.cp.line_size_bytes) # Now do some accesses within the cached memory for i in range(50): - yield from self.call_cache(random.randrange(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4)) + await self.call_cache(sim, random.randrange(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4)) assert len(self.refill_requests) == 0 with self.run_simulation(self.m) as sim: - sim.add_process(cache_user_process) + sim.add_testbench(cache_user_process) def test_2_way(self): self.init_module(2, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): # Fill the first set of both ways - yield from self.call_cache(0x00010000) - yield from self.call_cache(0x00020000) + await self.call_cache(sim, 0x00010000) + await self.call_cache(sim, 0x00020000) self.expect_refill(0x00010000) self.expect_refill(0x00020000) # And now both lines should be in the cache - yield from self.call_cache(0x00010004) - yield from self.call_cache(0x00020004) + await self.call_cache(sim, 0x00010004) + await self.call_cache(sim, 0x00020004) assert len(self.refill_requests) == 0 with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) # Tests whether the cache is fully pipelined and the latency between requests and response is exactly one cycle. def test_pipeline(self): self.init_module(2, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): # Fill the cache for i in range(self.cp.num_of_sets): addr = 0x00010000 + i * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) self.expect_refill(addr) - yield from self.tick(5) + await self.tick(sim, 4) # Create a stream of requests to ensure the pipeline is working - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) for i in range(0, self.cp.num_of_sets * self.cp.line_size_bytes, 4): addr = 0x00010000 + i self.issued_requests.append(addr) # Send the request - yield from self.m.issue_req.call_init(addr=addr) - yield Settle() - assert (yield from self.m.issue_req.done()) + ret = await self.m.issue_req.call_try(sim, addr=addr) + assert ret is not None # After a cycle the response should be ready - yield Tick() - yield from self.expect_resp() - yield from self.m.issue_req.disable() + await self.expect_resp(sim) - yield Tick() - yield from self.m.accept_res.disable() + self.m.accept_res.disable(sim) - yield from self.tick(5) + await self.tick(sim, 4) # Check how the cache handles queuing the requests - yield from self.send_req(addr=0x00010000 + 3 * self.cp.line_size_bytes) - yield from self.send_req(addr=0x00010004) + await self.send_req(sim, addr=0x00010000 + 3 * self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00010004) # Wait a few cycles. There are two requests queued - yield from self.tick(5) + await self.tick(sim, 4) - yield from self.m.accept_res.enable() - yield from self.expect_resp() - yield Tick() - yield from self.expect_resp() - yield from self.send_req(addr=0x0001000C) - yield from self.expect_resp() + self.m.accept_res.enable(sim) + await self.expect_resp( + sim, + ) + await self.expect_resp( + sim, + ) + await self.send_req(sim, addr=0x0001000C) + await self.expect_resp( + sim, + ) - yield Tick() - yield from self.m.accept_res.disable() + self.m.accept_res.disable(sim) - yield from self.tick(5) + await self.tick(sim, 4) # Schedule two requests, the first one causing a cache miss - yield from self.send_req(addr=0x00020000) - yield from self.send_req(addr=0x00010000 + self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00020000) + await self.send_req(sim, addr=0x00010000 + self.cp.line_size_bytes) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp() - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp( + sim, + ) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 2) # Schedule two requests, the second one causing a cache miss - yield from self.send_req(addr=0x00020004) - yield from self.send_req(addr=0x00030000 + self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00020004) + await self.send_req(sim, addr=0x00030000 + self.cp.line_size_bytes) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp() - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp( + sim, + ) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 2) # Schedule two requests, both causing a cache miss - yield from self.send_req(addr=0x00040000) - yield from self.send_req(addr=0x00050000 + self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00040000) + await self.send_req(sim, addr=0x00050000 + self.cp.line_size_bytes) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) def test_flush(self): self.init_module(2, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): # Fill the whole cache for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): addr = w * 0x00010000 + s * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) self.expect_refill(addr) # Everything should be in the cache for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): addr = w * 0x00010000 + s * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) assert len(self.refill_requests) == 0 - yield from self.m.flush_cache.call() + await self.m.flush_cache.call(sim) # The cache should be empty for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): addr = w * 0x00010000 + s * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) self.expect_refill(addr) # Try to flush during refilling the line - yield from self.send_req(0x00030000) - yield from self.m.flush_cache.call() + await self.send_req(sim, 0x00030000) + await self.m.flush_cache.call(sim) # We still should be able to accept the response for the last request - self.assert_resp((yield from self.m.accept_res.call())) + self.assert_resp(await self.m.accept_res.call(sim)) self.expect_refill(0x00030000) - yield from self.call_cache(0x00010000) + await self.call_cache(sim, 0x00010000) self.expect_refill(0x00010000) - yield Tick() - # Try to execute issue_req and flush_cache methods at the same time - yield from self.m.issue_req.call_init(addr=0x00010000) self.issued_requests.append(0x00010000) - yield from self.m.flush_cache.call_init() - yield Settle() - assert not (yield from self.m.issue_req.done()) - assert (yield from self.m.flush_cache.done()) - yield Tick() - yield from self.m.flush_cache.call_do() - yield from self.m.issue_req.call_do() - self.assert_resp((yield from self.m.accept_res.call())) + issue_req_res, flush_cache_res = ( + await CallTrigger(sim).call(self.m.issue_req, addr=0x00010000).call(self.m.flush_cache) + ) + assert issue_req_res is None + assert flush_cache_res is not None + await self.m.issue_req.call(sim, addr=0x00010000) + self.assert_resp(await self.m.accept_res.call(sim)) self.expect_refill(0x00010000) - yield Tick() - # Schedule two requests and then flush - yield from self.send_req(0x00000000 + self.cp.line_size_bytes) - yield from self.send_req(0x00010000) + await self.send_req(sim, 0x00000000 + self.cp.line_size_bytes) + await self.send_req(sim, 0x00010000) - yield from self.m.flush_cache.call_init() - yield Tick() + res = await self.m.flush_cache.call_try(sim) # We cannot flush until there are two pending requests - assert not (yield from self.m.flush_cache.done()) - yield Tick() - yield from self.m.flush_cache.disable() - yield Tick() + assert res is None + res = await self.m.flush_cache.call_try(sim) + assert res is None # Accept the first response - self.assert_resp((yield from self.m.accept_res.call())) + self.assert_resp(await self.m.accept_res.call(sim)) - yield from self.m.flush_cache.call() + await self.m.flush_cache.call(sim) # And accept the second response ensuring that we got old data - self.assert_resp((yield from self.m.accept_res.call())) + self.assert_resp(await self.m.accept_res.call(sim)) self.expect_refill(0x00000000 + self.cp.line_size_bytes) # Just make sure that the line is truly flushed - yield from self.call_cache(0x00010000) + await self.call_cache(sim, 0x00010000) self.expect_refill(0x00010000) with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) def test_errors(self): self.init_module(1, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): self.add_bad_addr(0x00010000) # Bad addr at the beggining of the line self.add_bad_addr(0x00020008) # Bad addr in the middle of the line self.add_bad_addr( 0x00030000 + self.cp.line_size_bytes - self.cp.word_width_bytes ) # Bad addr at the end of the line - yield from self.call_cache(0x00010008) + await self.call_cache(sim, 0x00010008) self.expect_refill(0x00010000) # Requesting a bad addr again should retrigger refill - yield from self.call_cache(0x00010008) + await self.call_cache(sim, 0x00010008) self.expect_refill(0x00010000) - yield from self.call_cache(0x00020000) + await self.call_cache(sim, 0x00020000) self.expect_refill(0x00020000) - yield from self.call_cache(0x00030008) + await self.call_cache(sim, 0x00030008) self.expect_refill(0x00030000) # Test how pipelining works with errors - yield from self.m.accept_res.disable() - yield Tick() + self.m.accept_res.disable(sim) # Schedule two requests, the first one causing an error - yield from self.send_req(addr=0x00020000) - yield from self.send_req(addr=0x00011000) + await self.send_req(sim, addr=0x00020000) + await self.send_req(sim, addr=0x00011000) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 3) # Schedule two requests, the second one causing an error - yield from self.send_req(addr=0x00021004) - yield from self.send_req(addr=0x00030000) + await self.send_req(sim, addr=0x00021004) + await self.send_req(sim, addr=0x00030000) - yield from self.tick(10) + await self.tick(sim, 10) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 3) # Schedule two requests, both causing an error - yield from self.send_req(addr=0x00020000) - yield from self.send_req(addr=0x00010000) + await self.send_req(sim, addr=0x00020000) + await self.send_req(sim, addr=0x00010000) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() - yield Tick() + await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) # The second request will cause an error - yield from self.send_req(addr=0x00021004) - yield from self.send_req(addr=0x00030000) + await self.send_req(sim, addr=0x00021004) + await self.send_req(sim, addr=0x00030000) - yield from self.tick(10) + await self.tick(sim, 10) # Accept the first response - yield from self.m.accept_res.enable() - yield from self.expect_resp(wait=True) - yield Tick() + self.m.accept_res.enable(sim) + await self.expect_resp(sim, wait=True) # Wait before accepting the second response - yield from self.m.accept_res.disable() - yield from self.tick(10) - yield from self.m.accept_res.enable() - yield from self.expect_resp(wait=True) - - yield Tick() + self.m.accept_res.disable(sim) + await self.tick(sim, 10) + self.m.accept_res.enable(sim) + await self.expect_resp(sim, wait=True) # This request should not cause an error - yield from self.send_req(addr=0x00011000) - yield from self.expect_resp(wait=True) + await self.send_req(sim, addr=0x00011000) + await self.expect_resp(sim, wait=True) with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) def test_random(self): self.init_module(4, 8) @@ -765,34 +741,28 @@ def test_random(self): if random.random() < 0.05: self.add_bad_addr(i) - def refiller_ctrl(): - yield Passive() - + async def refiller_ctrl(sim: TestbenchContext): while True: - yield from self.random_wait_geom(0.4) + await self.random_wait_geom(sim, 0.4) self.accept_refill_request = False - yield from self.random_wait_geom(0.7) + await self.random_wait_geom(sim, 0.7) self.accept_refill_request = True - def sender(): + async def sender(sim: TestbenchContext): for _ in range(iterations): - yield from self.send_req(random.randrange(0, max_addr, 4)) + await self.send_req(sim, random.randrange(0, max_addr, 4)) + await self.random_wait_geom(sim, 0.5) - while random.random() < 0.5: - yield Tick() - - def receiver(): + async def receiver(sim: TestbenchContext): for _ in range(iterations): while len(self.issued_requests) == 0: - yield Tick() - - self.assert_resp((yield from self.m.accept_res.call())) + await sim.tick() - while random.random() < 0.2: - yield Tick() + self.assert_resp(await self.m.accept_res.call(sim)) + await self.random_wait_geom(sim, 0.2) with self.run_simulation(self.m) as sim: - sim.add_process(sender) - sim.add_process(receiver) - sim.add_process(refiller_ctrl) + sim.add_testbench(sender) + sim.add_testbench(receiver) + sim.add_testbench(refiller_ctrl, background=True) diff --git a/test/core_structs/test_rat.py b/test/core_structs/test_rat.py index 01809d677..57093bb97 100644 --- a/test/core_structs/test_rat.py +++ b/test/core_structs/test_rat.py @@ -1,4 +1,4 @@ -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.core_structs.rat import FRAT, RRAT from coreblocks.params import GenParams @@ -7,6 +7,8 @@ from collections import deque from random import Random +from transactron.testing.testbenchio import CallTrigger + class TestFrontendRegisterAliasTable(TestCaseWithSimulator): def gen_input(self): @@ -18,14 +20,18 @@ def gen_input(self): self.to_execute_list.append({"rl": rl, "rp": rp, "rl_s1": rl_s1, "rl_s2": rl_s2}) - def do_rename(self): + async def do_rename(self, sim: TestbenchContext): for _ in range(self.test_steps): to_execute = self.to_execute_list.pop() - res = yield from self.m.rename.call( - rl_dst=to_execute["rl"], rp_dst=to_execute["rp"], rl_s1=to_execute["rl_s1"], rl_s2=to_execute["rl_s2"] + res = await self.m.rename.call( + sim, + rl_dst=to_execute["rl"], + rp_dst=to_execute["rp"], + rl_s1=to_execute["rl_s1"], + rl_s2=to_execute["rl_s2"], ) - assert res["rp_s1"] == self.expected_entries[to_execute["rl_s1"]] - assert res["rp_s2"] == self.expected_entries[to_execute["rl_s2"]] + assert res.rp_s1 == self.expected_entries[to_execute["rl_s1"]] + assert res.rp_s2 == self.expected_entries[to_execute["rl_s2"]] self.expected_entries[to_execute["rl"]] = to_execute["rp"] @@ -44,7 +50,7 @@ def test_single(self): self.gen_input() with self.run_simulation(m) as sim: - sim.add_process(self.do_rename) + sim.add_testbench(self.do_rename) class TestRetirementRegisterAliasTable(TestCaseWithSimulator): @@ -55,14 +61,17 @@ def gen_input(self): self.to_execute_list.append({"rl": rl, "rp": rp}) - def do_commit(self): + async def do_commit(self, sim: TestbenchContext): for _ in range(self.test_steps): to_execute = self.to_execute_list.pop() - yield from self.m.peek.call_init(rl_dst=to_execute["rl"]) - res = yield from self.m.commit.call(rl_dst=to_execute["rl"], rp_dst=to_execute["rp"]) - peek_res = yield from self.m.peek.call_do() - assert res["old_rp_dst"] == self.expected_entries[to_execute["rl"]] - assert peek_res["old_rp_dst"] == res["old_rp_dst"] + peek_res, res = ( + await CallTrigger(sim) + .call(self.m.peek, rl_dst=to_execute["rl"]) + .call(self.m.commit, rl_dst=to_execute["rl"], rp_dst=to_execute["rp"]) + ) + assert peek_res is not None and res is not None + assert res.old_rp_dst == self.expected_entries[to_execute["rl"]] + assert peek_res.old_rp_dst == res["old_rp_dst"] self.expected_entries[to_execute["rl"]] = to_execute["rp"] @@ -81,4 +90,4 @@ def test_single(self): self.gen_input() with self.run_simulation(m) as sim: - sim.add_process(self.do_commit) + sim.add_testbench(self.do_commit) diff --git a/test/core_structs/test_reorder_buffer.py b/test/core_structs/test_reorder_buffer.py index 0589e7db1..b1f935a81 100644 --- a/test/core_structs/test_reorder_buffer.py +++ b/test/core_structs/test_reorder_buffer.py @@ -1,6 +1,4 @@ -from amaranth.sim import Passive, Settle, Tick - -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.core_structs.rob import ReorderBuffer from coreblocks.params import GenParams @@ -9,54 +7,50 @@ from queue import Queue from random import Random +from transactron.testing.functions import data_const_to_dict + class TestReorderBuffer(TestCaseWithSimulator): - def gen_input(self): + async def gen_input(self, sim: TestbenchContext): for _ in range(self.test_steps): while self.regs_left_queue.empty(): - yield Tick() + await sim.tick() - while self.rand.random() < 0.5: - yield # to slow down puts + await self.random_wait_geom(sim, 0.5) # to slow down puts log_reg = self.rand.randint(0, self.log_regs - 1) phys_reg = self.regs_left_queue.get() regs = {"rl_dst": log_reg, "rp_dst": phys_reg} - rob_id = yield from self.m.put.call(regs) + rob_id = (await self.m.put.call(sim, regs)).rob_id self.to_execute_list.append((rob_id, phys_reg)) - self.retire_queue.put((regs, rob_id["rob_id"])) + self.retire_queue.put((regs, rob_id)) - def do_updates(self): - yield Passive() + async def do_updates(self, sim: TestbenchContext): while True: - while self.rand.random() < 0.5: - yield # to slow down execution + await self.random_wait_geom(sim, 0.5) # to slow down execution if len(self.to_execute_list) == 0: - yield Tick() + await sim.tick() else: idx = self.rand.randint(0, len(self.to_execute_list) - 1) rob_id, executed = self.to_execute_list.pop(idx) self.executed_list.append(executed) - yield from self.m.mark_done.call(rob_id) + await self.m.mark_done.call(sim, rob_id=rob_id, exception=0) - def do_retire(self): + async def do_retire(self, sim: TestbenchContext): cnt = 0 while True: if self.retire_queue.empty(): - self.m.retire.enable() - yield Tick() - is_ready = yield self.m.retire.adapter.done - assert is_ready == 0 # transaction should not be ready if there is nothing to retire + res = await self.m.retire.call_try(sim) + assert res is None # transaction should not be ready if there is nothing to retire else: regs, rob_id_exp = self.retire_queue.get() - results = yield from self.m.peek.call() - yield from self.m.retire.call() - phys_reg = results["rob_data"]["rp_dst"] - assert rob_id_exp == results["rob_id"] + results = await self.m.peek.call(sim) + await self.m.retire.call(sim) + phys_reg = results.rob_data.rp_dst + assert rob_id_exp == results.rob_id assert phys_reg in self.executed_list self.executed_list.remove(phys_reg) - yield Settle() - assert results["rob_data"] == regs + assert data_const_to_dict(results.rob_data) == regs self.regs_left_queue.put(phys_reg) cnt += 1 @@ -82,40 +76,38 @@ def test_single(self): self.log_regs = self.gen_params.isa.reg_cnt with self.run_simulation(m) as sim: - sim.add_process(self.gen_input) - sim.add_process(self.do_updates) - sim.add_process(self.do_retire) + sim.add_testbench(self.gen_input) + sim.add_testbench(self.do_updates, background=True) + sim.add_testbench(self.do_retire) class TestFullDoneCase(TestCaseWithSimulator): - def gen_input(self): + async def gen_input(self, sim: TestbenchContext): for _ in range(self.test_steps): log_reg = self.rand.randrange(self.log_regs) phys_reg = self.rand.randrange(self.phys_regs) - rob_id = yield from self.m.put.call(rl_dst=log_reg, rp_dst=phys_reg) + rob_id = (await self.m.put.call(sim, rl_dst=log_reg, rp_dst=phys_reg)).rob_id self.to_execute_list.append(rob_id) - def do_single_update(self): + async def do_single_update(self, sim: TestbenchContext): while len(self.to_execute_list) == 0: - yield Tick() + await sim.tick() rob_id = self.to_execute_list.pop(0) - yield from self.m.mark_done.call(rob_id) + await self.m.mark_done.call(sim, rob_id=rob_id) - def do_retire(self): + async def do_retire(self, sim: TestbenchContext): for i in range(self.test_steps - 1): - yield from self.do_single_update() + await self.do_single_update(sim) - yield from self.m.retire.call() - yield from self.do_single_update() + await self.m.retire.call(sim) + await self.do_single_update(sim) for i in range(self.test_steps - 1): - yield from self.m.retire.call() + await self.m.retire.call(sim) - yield from self.m.retire.enable() - yield Tick() - res = yield self.m.retire.adapter.done - assert res == 0 # should be disabled, since we have read all elements + res = await self.m.retire.call_try(sim) + assert res is None # since we have read all elements def test_single(self): self.rand = Random(0) @@ -130,5 +122,5 @@ def test_single(self): self.phys_regs = 2**self.gen_params.phys_regs_bits with self.run_simulation(m) as sim: - sim.add_process(self.gen_input) - sim.add_process(self.do_retire) + sim.add_testbench(self.gen_input) + sim.add_testbench(self.do_retire) diff --git a/test/frontend/test_decode_stage.py b/test/frontend/test_decode_stage.py index 8cfcb95fd..acab29abd 100644 --- a/test/frontend/test_decode_stage.py +++ b/test/frontend/test_decode_stage.py @@ -1,7 +1,7 @@ import pytest from transactron.lib import AdapterTrans, FIFO - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, SimpleTestCircuit, ModuleConnector +from transactron.utils.amaranth_ext.elaboratables import ModuleConnector +from transactron.testing import TestCaseWithSimulator, TestbenchIO, SimpleTestCircuit, TestbenchContext from coreblocks.frontend.decoder.decode_stage import DecodeStage from coreblocks.params import GenParams @@ -32,10 +32,10 @@ def setup(self, fixture_initialize_testing_env): ) ) - def decode_test_proc(self): + async def decode_test_proc(self, sim: TestbenchContext): # testing an OP_IMM instruction (test copied from test_decoder.py) - yield from self.fifo_in_write.call(instr=0x02A28213) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x02A28213) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.ARITHMETIC assert decoded["exec_fn"]["funct3"] == Funct3.ADD @@ -46,8 +46,8 @@ def decode_test_proc(self): assert decoded["imm"] == 42 # testing an OP instruction (test copied from test_decoder.py) - yield from self.fifo_in_write.call(instr=0x003100B3) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x003100B3) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.ARITHMETIC assert decoded["exec_fn"]["funct3"] == Funct3.ADD @@ -57,8 +57,8 @@ def decode_test_proc(self): assert decoded["regs_l"]["rl_s2"] == 3 # testing an illegal - yield from self.fifo_in_write.call(instr=0x0) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x0) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.EXCEPTION assert decoded["exec_fn"]["funct3"] == Funct3._EILLEGALINSTR @@ -67,8 +67,8 @@ def decode_test_proc(self): assert decoded["regs_l"]["rl_s1"] == 0 assert decoded["regs_l"]["rl_s2"] == 0 - yield from self.fifo_in_write.call(instr=0x0, access_fault=1) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x0, access_fault=1) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.EXCEPTION assert decoded["exec_fn"]["funct3"] == Funct3._EINSTRACCESSFAULT @@ -79,4 +79,4 @@ def decode_test_proc(self): def test(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.decode_test_proc) + sim.add_testbench(self.decode_test_proc) diff --git a/test/frontend/test_fetch.py b/test/frontend/test_fetch.py index 33b216752..5e2406776 100644 --- a/test/frontend/test_fetch.py +++ b/test/frontend/test_fetch.py @@ -6,13 +6,20 @@ import random from amaranth import Elaboratable, Module -from amaranth.sim import Passive, Tick from coreblocks.interface.keys import FetchResumeKey from transactron.core import Method from transactron.lib import AdapterTrans, Adapter, BasicFifo +from transactron.testing.method_mock import MethodMock from transactron.utils import ModuleConnector -from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, SimpleTestCircuit, TestGen +from transactron.testing import ( + TestCaseWithSimulator, + TestbenchIO, + def_method_mock, + SimpleTestCircuit, + TestbenchContext, + ProcessContext, +) from coreblocks.frontend.fetch.fetch import FetchUnit, PredictionChecker from coreblocks.cache.iface import CacheInterface @@ -133,15 +140,12 @@ def gen_branch(self, offset: int, taken: bool): return self.add_instr(data, True, jump_offset=offset, branch_taken=taken) - def cache_process(self): - yield Passive() - + async def cache_process(self, sim: ProcessContext): while True: while len(self.input_q) == 0: - yield Tick() + await sim.tick() - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) req_addr = self.input_q.popleft() & ~(self.gen_params.fetch_block_bytes - 1) @@ -162,15 +166,24 @@ def load_or_gen_mem(addr): self.output_q.append({"fetch_block": fetch_block, "error": bad_addr}) - @def_method_mock(lambda self: self.icache.issue_req_io, enable=lambda self: len(self.input_q) < 2, sched_prio=1) + @def_method_mock( + lambda self: self.icache.issue_req_io, enable=lambda self: len(self.input_q) < 2 + ) # TODO had sched_prio def issue_req_mock(self, addr): - self.input_q.append(addr) + @MethodMock.effect + def eff(): + self.input_q.append(addr) @def_method_mock(lambda self: self.icache.accept_res_io, enable=lambda self: len(self.output_q) > 0) def accept_res_mock(self): - return self.output_q.popleft() + @MethodMock.effect + def eff(): + self.output_q.popleft() - def fetch_out_check(self): + if self.output_q: + return self.output_q[0] + + async def fetch_out_check(self, sim: TestbenchContext): while self.instr_queue: instr = self.instr_queue.popleft() @@ -178,7 +191,7 @@ def fetch_out_check(self): if not instr["rvc"]: access_fault |= instr["pc"] + 2 in self.memerr - v = yield from self.io_out.call() + v = await self.io_out.call(sim) assert v["pc"] == instr["pc"] assert v["access_fault"] == access_fault @@ -188,13 +201,13 @@ def fetch_out_check(self): assert v["instr"] == instr_data if (instr["jumps"] and (instr["branch_taken"] != v["predicted_taken"])) or access_fault: - yield from self.random_wait(5) - yield from self.fetch.stall_exception.call() - yield from self.random_wait(5) + await self.random_wait(sim, 5) + await self.fetch.stall_exception.call(sim) + await self.random_wait(sim, 5) # Empty the pipeline - yield from self.clean_fifo.call_try() - yield Tick() + await self.clean_fifo.call_try(sim) + await sim.tick() resume_pc = instr["next_pc"] if access_fault: @@ -204,13 +217,13 @@ def fetch_out_check(self): ) + self.gen_params.fetch_block_bytes # Resume the fetch unit - while (yield from self.fetch.resume_from_exception.call_try(pc=resume_pc)) is None: + while await self.fetch.resume_from_exception.call_try(sim, pc=resume_pc) is None: pass def run_sim(self): with self.run_simulation(self.m) as sim: sim.add_process(self.cache_process) - sim.add_process(self.fetch_out_check) + sim.add_testbench(self.fetch_out_check) def test_simple_no_jumps(self): for _ in range(50): @@ -390,7 +403,7 @@ def test_random(self): with self.run_simulation(self.m) as sim: sim.add_process(self.cache_process) - sim.add_process(self.fetch_out_check) + sim.add_testbench(self.fetch_out_check) @dataclass(frozen=True) @@ -424,8 +437,9 @@ def setup(self, fixture_initialize_testing_env): self.m = SimpleTestCircuit(PredictionChecker(self.gen_params)) - def check( + async def check( self, + sim: TestbenchContext, pc: int, block_cross: bool, predecoded: list[tuple[CfiType, int]], @@ -434,7 +448,7 @@ def check( cfi_type: CfiType, cfi_target: Optional[int], valid_mask: int = -1, - ) -> TestGen[CheckerResult]: + ) -> CheckerResult: # Fill the array with non-CFI instructions for _ in range(self.gen_params.fetch_width - len(predecoded)): predecoded.append((CfiType.INVALID, 0)) @@ -457,7 +471,8 @@ def check( instr_valid = (((1 << self.gen_params.fetch_width) - 1) << instr_start) & valid_mask - res = yield from self.m.check.call( + res = await self.m.check.call( + sim, fb_addr=pc >> self.gen_params.fetch_block_bytes_log, instr_block_cross=block_cross, instr_valid=instr_valid, @@ -493,46 +508,46 @@ def test_no_misprediction(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width - def proc(): + async def proc(sim: TestbenchContext): # No CFI at all - ret = yield from self.check(0x100, False, [], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, False, [], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=False) # There is one forward branch that we didn't predict - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=False) # There are many forward branches that we didn't predict - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, 100)] * fetch_width, 0, 0, CfiType.INVALID, None + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, 100)] * fetch_width, 0, 0, CfiType.INVALID, None ) self.assert_resp(ret, mispredicted=False) # There is a predicted JAL instr - ret = yield from self.check(0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 0x100 + 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 0x100 + 100) self.assert_resp(ret, mispredicted=False) # There is a predicted JALR instr - the predecoded offset can now be anything - ret = yield from self.check(0x100, False, [(CfiType.JALR, 200)], 0, 0, CfiType.JALR, 0x100 + 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JALR, 200)], 0, 0, CfiType.JALR, 0x100 + 100) self.assert_resp(ret, mispredicted=False) # There is a forward taken-predicted branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 0x100 + 100) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 0x100 + 100) self.assert_resp(ret, mispredicted=False) # There is a backward taken-predicted branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100) self.assert_resp(ret, mispredicted=False) # Branch located between two fetch blocks if self.with_rvc: - ret = yield from self.check( - 0x100, True, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100 - 2 + ret = await self.check( + sim, 0x100, True, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100 - 2 ) self.assert_resp(ret, mispredicted=False) # One branch predicted as not taken - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.INVALID, 0) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.INVALID, 0) self.assert_resp(ret, mispredicted=False) # Now tests for fetch blocks with multiple instructions @@ -540,7 +555,8 @@ def proc(): return # Predicted taken branch as the second instruction - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], @@ -552,13 +568,14 @@ def proc(): self.assert_resp(ret, mispredicted=False) # Predicted, but not taken branch as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b10, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b10, 0, CfiType.INVALID, 0 ) self.assert_resp(ret, mispredicted=False) if self.with_rvc: - ret = yield from self.check( + ret = await self.check( + sim, 0x100, True, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], @@ -569,7 +586,8 @@ def proc(): ) self.assert_resp(ret, mispredicted=False) - ret = yield from self.check( + ret = await self.check( + sim, 0x100, True, [(CfiType.JAL, 100), (CfiType.JAL, -100)], @@ -582,15 +600,16 @@ def proc(): self.assert_resp(ret, mispredicted=False) # Two branches with all possible combintations taken/not-taken - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.INVALID, 0 ) self.assert_resp(ret, mispredicted=False) - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.BRANCH, 0x100 - 100 + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.BRANCH, 0x100 - 100 ) self.assert_resp(ret, mispredicted=False) - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], @@ -602,17 +621,25 @@ def proc(): self.assert_resp(ret, mispredicted=False) # JAL at the beginning, but we start from the second instruction - ret = yield from self.check(0x100 + instr_width, False, [(CfiType.JAL, -100)], 0b0, 0, CfiType.INVALID, 0) + ret = await self.check(sim, 0x100 + instr_width, False, [(CfiType.JAL, -100)], 0b0, 0, CfiType.INVALID, 0) self.assert_resp(ret, mispredicted=False) # JAL and a forward branch that we didn't predict - ret = yield from self.check( - 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, + 0x100 + instr_width, + False, + [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], + 0b00, + 0, + CfiType.INVALID, + 0, ) self.assert_resp(ret, mispredicted=False) # two JAL instructions, but we start from the second one - ret = yield from self.check( + ret = await self.check( + sim, 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.JAL, 100)], @@ -624,7 +651,8 @@ def proc(): self.assert_resp(ret, mispredicted=False) # JAL and a branch, but we start from the second instruction - ret = yield from self.check( + ret = await self.check( + sim, 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], @@ -636,24 +664,24 @@ def proc(): self.assert_resp(ret, mispredicted=False) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) def test_preceding_redirection(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width - def proc(): + async def proc(sim: TestbenchContext): # No prediction was made, but there is a JAL at the beginning - ret = yield from self.check(0x100, False, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20) # The same, but the jump is between two fetch blocks if self.with_rvc: - ret = yield from self.check(0x100, True, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20 - 2) # Not predicted backward branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0b0, 0, CfiType.INVALID, 0) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b0, 0, CfiType.INVALID, 0) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # Now tests for fetch blocks with multiple instructions @@ -661,7 +689,8 @@ def proc(): return # We predicted the branch on the second instruction, but there's a JAL on the first one. - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], @@ -673,7 +702,8 @@ def proc(): self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # We predicted the branch on the second instruction, but there's a JALR on the first one. - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.JALR, -100), (CfiType.BRANCH, 100)], @@ -685,7 +715,8 @@ def proc(): self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0) # We predicted the branch on the second instruction, but there's a backward on the first one. - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], @@ -697,31 +728,32 @@ def proc(): self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # Unpredicted backward branch as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b00, 0, CfiType.INVALID, 0 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width - 100 ) # Unpredicted JAL as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b00, 0, CfiType.INVALID, 0 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100 ) # Unpredicted JALR as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.JALR, 100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JALR, 100)], 0b00, 0, CfiType.INVALID, 0 ) self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=1) if fetch_width < 3: return - ret = yield from self.check( + ret = await self.check( + sim, 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.INVALID, 100), (CfiType.JAL, 100)], @@ -735,94 +767,101 @@ def proc(): ) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) def test_mispredicted_cfi_type(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width fb_bytes = self.gen_params.fetch_block_bytes - def proc(): + async def proc(sim: TestbenchContext): # We predicted a JAL, but in fact there is a non-CFI instruction - ret = yield from self.check(0x100, False, [(CfiType.INVALID, 0)], 0, 0, CfiType.JAL, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.INVALID, 0)], 0, 0, CfiType.JAL, 100) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes ) # We predicted a JAL, but in fact there is a branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0, 0, CfiType.JAL, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0, 0, CfiType.JAL, 100) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # We predicted a JAL, but in fact there is a JALR instruction - ret = yield from self.check(0x100, False, [(CfiType.JALR, -100)], 0, 0, CfiType.JAL, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JALR, -100)], 0, 0, CfiType.JAL, 100) self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0) # We predicted a branch, but in fact there is a JAL - ret = yield from self.check(0x100, False, [(CfiType.JAL, -100)], 0b1, 0, CfiType.BRANCH, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, -100)], 0b1, 0, CfiType.BRANCH, 100) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) if fetch_width < 2: return # There is a branch and a non-CFI, but we predicted two branches - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100 + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes ) # The same as above, but we start from the second instruction - ret = yield from self.check( - 0x100 + instr_width, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100 + ret = await self.check( + sim, + 0x100 + instr_width, + False, + [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], + 0b11, + 1, + CfiType.BRANCH, + 100, ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes ) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) def test_mispredicted_cfi_target(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width - def proc(): + async def proc(sim: TestbenchContext): # We predicted a wrong JAL target - ret = yield from self.check(0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 200) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 200) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100) # We predicted a wrong branch target - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 200) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 200) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100) # We didn't provide the branch target - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, None) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, None) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100) # We predicted a wrong JAL target that is between two fetch blocks if self.with_rvc: - ret = yield from self.check(0x100, True, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 300) + ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 300) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100 - 2) if fetch_width < 2: return # The second instruction is a branch without the target - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, 100)], 0b10, 1, CfiType.BRANCH, None + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, 100)], 0b10, 1, CfiType.BRANCH, None ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100 ) # The second instruction is a JAL with a wrong target - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b10, 1, CfiType.JAL, 200 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b10, 1, CfiType.JAL, 200 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100 ) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) diff --git a/test/frontend/test_instr_decoder.py b/test/frontend/test_instr_decoder.py index 1830f1063..f7fc44c30 100644 --- a/test/frontend/test_instr_decoder.py +++ b/test/frontend/test_instr_decoder.py @@ -1,6 +1,6 @@ from amaranth.sim import * -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext from coreblocks.params import * from coreblocks.params.configurations import test_core_config @@ -181,64 +181,63 @@ def setup_method(self): self.cnt = 1 def do_test(self, tests: list[InstrTest]): - def process(): + async def process(sim: TestbenchContext): for test in tests: - yield self.decoder.instr.eq(test.encoding) - yield Settle() + sim.set(self.decoder.instr, test.encoding) - assert (yield self.decoder.illegal) == test.illegal + assert sim.get(self.decoder.illegal) == test.illegal if test.illegal: return - assert (yield self.decoder.opcode) == test.opcode + assert sim.get(self.decoder.opcode) == test.opcode if test.funct3 is not None: - assert (yield self.decoder.funct3) == test.funct3 - assert (yield self.decoder.funct3_v) == (test.funct3 is not None) + assert sim.get(self.decoder.funct3) == test.funct3 + assert sim.get(self.decoder.funct3_v) == (test.funct3 is not None) if test.funct7 is not None: - assert (yield self.decoder.funct7) == test.funct7 - assert (yield self.decoder.funct7_v) == (test.funct7 is not None) + assert sim.get(self.decoder.funct7) == test.funct7 + assert sim.get(self.decoder.funct7_v) == (test.funct7 is not None) if test.funct12 is not None: - assert (yield self.decoder.funct12) == test.funct12 - assert (yield self.decoder.funct12_v) == (test.funct12 is not None) + assert sim.get(self.decoder.funct12) == test.funct12 + assert sim.get(self.decoder.funct12_v) == (test.funct12 is not None) if test.rd is not None: - assert (yield self.decoder.rd) == test.rd - assert (yield self.decoder.rd_v) == (test.rd is not None) + assert sim.get(self.decoder.rd) == test.rd + assert sim.get(self.decoder.rd_v) == (test.rd is not None) if test.rs1 is not None: - assert (yield self.decoder.rs1) == test.rs1 - assert (yield self.decoder.rs1_v) == (test.rs1 is not None) + assert sim.get(self.decoder.rs1) == test.rs1 + assert sim.get(self.decoder.rs1_v) == (test.rs1 is not None) if test.rs2 is not None: - assert (yield self.decoder.rs2) == test.rs2 - assert (yield self.decoder.rs2_v) == (test.rs2 is not None) + assert sim.get(self.decoder.rs2) == test.rs2 + assert sim.get(self.decoder.rs2_v) == (test.rs2 is not None) if test.imm is not None: if test.csr is not None: # in CSR instruction additional fields are passed in unused bits of imm field - assert (yield self.decoder.imm.as_signed() & ((2**5) - 1)) == test.imm + assert sim.get(self.decoder.imm.as_signed() & ((2**5) - 1)) == test.imm else: - assert (yield self.decoder.imm.as_signed()) == test.imm + assert sim.get(self.decoder.imm.as_signed()) == test.imm if test.succ is not None: - assert (yield self.decoder.succ) == test.succ + assert sim.get(self.decoder.succ) == test.succ if test.pred is not None: - assert (yield self.decoder.pred) == test.pred + assert sim.get(self.decoder.pred) == test.pred if test.fm is not None: - assert (yield self.decoder.fm) == test.fm + assert sim.get(self.decoder.fm) == test.fm if test.csr is not None: - assert (yield self.decoder.csr) == test.csr + assert sim.get(self.decoder.csr) == test.csr - assert (yield self.decoder.optype) == test.op + assert sim.get(self.decoder.optype) == test.op with self.run_simulation(self.decoder) as sim: - sim.add_process(process) + sim.add_testbench(process) def test_i(self): self.do_test(self.DECODER_TESTS_I) @@ -280,14 +279,13 @@ def test_e(self): self.gen_params = GenParams(test_core_config.replace(embedded=True, _implied_extensions=Extension.E)) self.decoder = InstrDecoder(self.gen_params) - def process(): + async def process(sim: TestbenchContext): for encoding, illegal in self.E_TEST: - yield self.decoder.instr.eq(encoding) - yield Settle() - assert (yield self.decoder.illegal) == illegal + sim.set(self.decoder.instr, encoding) + assert sim.get(self.decoder.illegal) == illegal with self.run_simulation(self.decoder) as sim: - sim.add_process(process) + sim.add_testbench(process) class TestEncodingUniqueness(TestCase): diff --git a/test/frontend/test_rvc.py b/test/frontend/test_rvc.py index f1690f8dd..53dcaebd0 100644 --- a/test/frontend/test_rvc.py +++ b/test/frontend/test_rvc.py @@ -1,6 +1,5 @@ from parameterized import parameterized_class -from amaranth.sim import Settle, Tick from amaranth import * from coreblocks.frontend.decoder.rvc import InstrDecompress @@ -9,7 +8,7 @@ from coreblocks.params.configurations import test_core_config from transactron.utils import ValueLike -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext COMMON_TESTS = [ # Illegal instruction @@ -283,22 +282,18 @@ def test(self): ) self.m = InstrDecompress(self.gen_params) - def process(): - illegal = Signal(32) - yield illegal.eq(IllegalInstr()) + async def process(sim: TestbenchContext): + illegal = Const.cast(IllegalInstr()).value for instr_in, instr_out in self.test_cases: - yield self.m.instr_in.eq(instr_in) - expected = Signal(32) - yield expected.eq(instr_out) - yield Settle() + sim.set(self.m.instr_in, instr_in) + expected = Const.cast(instr_out).value - if (yield expected) == (yield illegal): - yield expected.eq(instr_in) # for exception handling - yield Settle() + if expected == illegal: + expected = instr_in # for exception handling - assert (yield self.m.instr_out) == (yield expected) - yield Tick() + assert sim.get(self.m.instr_out) == expected + await sim.tick() with self.run_simulation(self.m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/func_blocks/csr/test_csr.py b/test/func_blocks/csr/test_csr.py index da39d36b2..6fa8c95e7 100644 --- a/test/func_blocks/csr/test_csr.py +++ b/test/func_blocks/csr/test_csr.py @@ -1,5 +1,5 @@ from amaranth import * -from random import random +import random from transactron.lib import Adapter from transactron.core.tmodule import TModule @@ -17,6 +17,7 @@ CSRInstancesKey, ) from coreblocks.arch.isa_consts import PrivilegeLevel +from transactron.lib.adapters import AdapterTrans from transactron.utils.dependencies import DependencyContext from transactron.testing import * @@ -77,8 +78,8 @@ def make_csr(number: int): class TestCSRUnit(TestCaseWithSimulator): - def gen_expected_out(self, op, rd, rs1, operand_val, csr): - exp_read = {"rp_dst": rd, "result": (yield self.dut.csr[csr].value)} + def gen_expected_out(self, sim: TestbenchContext, op: Funct3, rd: int, rs1: int, operand_val: int, csr: int): + exp_read = {"rp_dst": rd, "result": sim.get(self.dut.csr[csr].value)} rs1_val = {"rp_s1": rs1, "value": operand_val} exp_write = {} @@ -89,11 +90,11 @@ def gen_expected_out(self, op, rd, rs1, operand_val, csr): elif (op == Funct3.CSRRS and rs1) or op == Funct3.CSRRSI: exp_write = {"csr": csr, "value": exp_read["result"] | operand_val} else: - exp_write = {"csr": csr, "value": (yield self.dut.csr[csr].value)} + exp_write = {"csr": csr, "value": sim.get(self.dut.csr[csr].value)} return {"exp_read": exp_read, "exp_write": exp_write, "rs1": rs1_val} - def generate_instruction(self): + def generate_instruction(self, sim: TestbenchContext): ops = [ Funct3.CSRRW, Funct3.CSRRC, @@ -113,7 +114,7 @@ def generate_instruction(self): operand_val = imm if imm_op else rs1_val csr = random.choice(list(self.dut.csr.keys())) - exp = yield from self.gen_expected_out(op, rd, rs1, operand_val, csr) + exp = self.gen_expected_out(sim, op, rd, rs1, operand_val, csr) value_available = random.random() < 0.2 @@ -130,36 +131,38 @@ def generate_instruction(self): "exp": exp, } - def process_test(self): - yield from self.dut.fetch_resume.enable() - yield from self.dut.exception_report.enable() + async def process_test(self, sim: TestbenchContext): + self.dut.fetch_resume.enable(sim) + self.dut.exception_report.enable(sim) for _ in range(self.cycles): - yield from self.random_wait_geom() + await self.random_wait_geom(sim) - op = yield from self.generate_instruction() + op = self.generate_instruction(sim) - yield from self.dut.select.call() + await self.dut.select.call(sim) - yield from self.dut.insert.call(rs_data=op["instr"]) + await self.dut.insert.call(sim, rs_data=op["instr"]) - yield from self.random_wait_geom() + await self.random_wait_geom(sim) if op["exp"]["rs1"]["rp_s1"]: - yield from self.dut.update.call(reg_id=op["exp"]["rs1"]["rp_s1"], reg_val=op["exp"]["rs1"]["value"]) + await self.dut.update.call(sim, reg_id=op["exp"]["rs1"]["rp_s1"], reg_val=op["exp"]["rs1"]["value"]) - yield from self.random_wait_geom() - yield from self.dut.precommit.method_handle( - function=lambda rob_id: {"side_fx": 1}, validate_arguments=lambda rob_id: True - ) + await self.random_wait_geom(sim) + # TODO: this is a hack, a real method mock should be used + for _, r in self.dut.precommit.adapter.validators: # type: ignore + sim.set(r, 1) + self.dut.precommit.call_init(sim, side_fx=1) # TODO: sensible precommit handling - yield from self.random_wait_geom() - res = yield from self.dut.accept.call() + await self.random_wait_geom(sim) + res, resume_res = await CallTrigger(sim).call(self.dut.accept).sample(self.dut.fetch_resume).until_done() + self.dut.precommit.disable(sim) - assert self.dut.fetch_resume.done() - assert res["rp_dst"] == op["exp"]["exp_read"]["rp_dst"] + assert res is not None and resume_res is not None + assert res.rp_dst == op["exp"]["exp_read"]["rp_dst"] if op["exp"]["exp_read"]["rp_dst"]: - assert res["result"] == op["exp"]["exp_read"]["result"] - assert (yield self.dut.csr[op["exp"]["exp_write"]["csr"]].value) == op["exp"]["exp_write"]["value"] - assert res["exception"] == 0 + assert res.result == op["exp"]["exp_read"]["result"] + assert sim.get(self.dut.csr[op["exp"]["exp_write"]["csr"]].value) == op["exp"]["exp_write"]["value"] + assert res.exception == 0 def test_randomized(self): self.gen_params = GenParams(test_core_config) @@ -171,7 +174,7 @@ def test_randomized(self): self.dut = CSRUnitTestCircuit(self.gen_params, self.csr_count) with self.run_simulation(self.dut) as sim: - sim.add_process(self.process_test) + sim.add_testbench(self.process_test) exception_csr_numbers = [ 0xCC0, # read_only @@ -179,21 +182,22 @@ def test_randomized(self): 0x7FE, # missing priv ] - def process_exception_test(self): - yield from self.dut.fetch_resume.enable() - yield from self.dut.exception_report.enable() + async def process_exception_test(self, sim: TestbenchContext): + self.dut.fetch_resume.enable(sim) + self.dut.exception_report.enable(sim) for csr in self.exception_csr_numbers: if csr == 0x7FE: - yield from self.dut.priv_io.call(data=PrivilegeLevel.USER) + await self.dut.priv_io.call(sim, data=PrivilegeLevel.USER) else: - yield from self.dut.priv_io.call(data=PrivilegeLevel.MACHINE) + await self.dut.priv_io.call(sim, data=PrivilegeLevel.MACHINE) - yield from self.random_wait_geom() + await self.random_wait_geom(sim) - yield from self.dut.select.call() + await self.dut.select.call(sim) - csr_rob_id = random.randrange(2**self.gen_params.rob_entries_bits) - yield from self.dut.insert.call( + rob_id = random.randrange(2**self.gen_params.rob_entries_bits) + await self.dut.insert.call( + sim, rs_data={ "exec_fn": {"op_type": OpType.CSR_REG, "funct3": Funct3.CSRRW, "funct7": 0}, "rp_s1": 0, @@ -202,24 +206,25 @@ def process_exception_test(self): "rp_dst": 2, "imm": 0, "csr": csr, - "rob_id": csr_rob_id, - } + "rob_id": rob_id, + }, ) - yield from self.random_wait_geom() - yield from self.dut.precommit.method_handle( - function=lambda rob_id: {"side_fx": 1}, validate_arguments=lambda rob_id: rob_id == csr_rob_id - ) + await self.random_wait_geom(sim) + # TODO: this is a hack, a real method mock should be used + for _, r in self.dut.precommit.adapter.validators: # type: ignore + sim.set(r, 1) + self.dut.precommit.call_init(sim, side_fx=1) - yield from self.random_wait_geom() - res = yield from self.dut.accept.call() + await self.random_wait_geom(sim) + res, report = await CallTrigger(sim).call(self.dut.accept).sample(self.dut.exception_report).until_done() + self.dut.precommit.disable(sim) assert res["exception"] == 1 - report = yield from self.dut.exception_report.call_result() assert report is not None - assert isinstance(report, dict) - report.pop("mtval") # mtval tested in mtval.asm test - assert {"rob_id": csr_rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report + report_dict = data_const_to_dict(report) + report_dict.pop("mtval") # mtval tested in mtval.asm test + assert {"rob_id": rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report_dict def test_exception(self): self.gen_params = GenParams(test_core_config) @@ -228,13 +233,13 @@ def test_exception(self): self.dut = CSRUnitTestCircuit(self.gen_params, 0, only_legal=False) with self.run_simulation(self.dut) as sim: - sim.add_process(self.process_exception_test) + sim.add_testbench(self.process_exception_test) class TestCSRRegister(TestCaseWithSimulator): - def randomized_process_test(self): + async def randomized_process_test(self, sim: TestbenchContext): # always enabled - yield from self.dut.read.enable() + self.dut.read.enable(sim) previous_data = 0 for _ in range(self.cycles): @@ -246,7 +251,7 @@ def randomized_process_test(self): if random.random() < 0.9: write = True exp_write_data = random.randint(0, 2**self.gen_params.isa.xlen - 1) - yield from self.dut.write.call_init(data=exp_write_data) + self.dut.write.call_init(sim, data=exp_write_data) if random.random() < 0.3: fu_write = True @@ -255,33 +260,32 @@ def randomized_process_test(self): exp_write_data = (write_arg & ~self.ro_mask) | ( (exp_write_data if exp_write_data is not None else previous_data) & self.ro_mask ) - yield from self.dut._fu_write.call_init(data=write_arg) + self.dut._fu_write.call_init(sim, data=write_arg) if random.random() < 0.2: fu_read = True - yield from self.dut._fu_read.enable() + self.dut._fu_read.call_init(sim) - yield Tick() - yield Settle() + await sim.tick() exp_read_data = exp_write_data if fu_write or write else previous_data if fu_read: # in CSRUnit this call is called before write and returns previous result - assert (yield from self.dut._fu_read.call_result()) == {"data": exp_read_data} + assert data_const_to_dict(self.dut._fu_read.get_call_result(sim)) == {"data": exp_read_data} - assert (yield from self.dut.read.call_result()) == { + assert data_const_to_dict(self.dut.read.get_call_result(sim)) == { "data": exp_read_data, "read": int(fu_read), "written": int(fu_write), } - read_result = yield from self.dut.read.call_result() + read_result = self.dut.read.get_call_result(sim) assert read_result is not None - previous_data = read_result["data"] + previous_data = read_result.data - yield from self.dut._fu_read.disable() - yield from self.dut._fu_write.disable() - yield from self.dut.write.disable() + self.dut._fu_read.disable(sim) + self.dut._fu_write.disable(sim) + self.dut.write.disable(sim) def test_randomized(self): self.gen_params = GenParams(test_core_config) @@ -293,15 +297,15 @@ def test_randomized(self): self.dut = SimpleTestCircuit(CSRRegister(0, self.gen_params, ro_bits=self.ro_mask)) with self.run_simulation(self.dut) as sim: - sim.add_process(self.randomized_process_test) + sim.add_testbench(self.randomized_process_test) - def filtermap_process_test(self): + async def filtermap_process_test(self, sim: TestbenchContext): prev_value = 0 for _ in range(50): input = random.randrange(0, 2**34) - yield from self.dut._fu_write.call({"data": input}) - output = (yield from self.dut._fu_read.call())["data"] + await self.dut._fu_write.call(sim, data=input) + output = (await self.dut._fu_read.call(sim))["data"] expected = prev_value if input & 1: @@ -341,43 +345,46 @@ def write_filtermap(m: TModule, v: Value): ro_bits=(1 << 32), fu_read_map=lambda _, v: v << 1, fu_write_filtermap=write_filtermap, - ) + ), ) with self.run_simulation(self.dut) as sim: - sim.add_process(self.filtermap_process_test) - - def comb_process_test(self): - yield from self.dut.read.enable() - yield from self.dut.read_comb.enable() - yield from self.dut._fu_read.enable() - - yield from self.dut._fu_write.call_init({"data": 0xFFFF}) - yield from self.dut._fu_write.call_do() - assert (yield from self.dut.read_comb.call_result())["data"] == 0xFFFF - assert (yield from self.dut._fu_read.call_result())["data"] == 0xAB - yield Tick() - assert (yield from self.dut.read.call_result())["data"] == 0xFFFB - assert (yield from self.dut._fu_read.call_result())["data"] == 0xFFFB - yield Tick() - - yield from self.dut._fu_write.call_init({"data": 0x0FFF}) - yield from self.dut.write.call_init({"data": 0xAAAA}) - yield from self.dut._fu_write.call_do() - yield from self.dut.write.call_do() - assert (yield from self.dut.read_comb.call_result()) == {"data": 0x0FFF, "read": 1, "written": 1} - yield Tick() - assert (yield from self.dut._fu_read.call_result())["data"] == 0xAAAA - yield Tick() + sim.add_testbench(self.filtermap_process_test) + + async def comb_process_test(self, sim: TestbenchContext): + self.dut.read.enable(sim) + self.dut.read_comb.enable(sim) + self.dut._fu_read.enable(sim) + + self.dut._fu_write.call_init(sim, data=0xFFFF) + while self.dut._fu_write.get_call_result(sim) is None: + await sim.tick() + assert self.dut.read_comb.get_call_result(sim).data == 0xFFFF + assert self.dut._fu_read.get_call_result(sim).data == 0xAB + await sim.tick() + assert self.dut.read.get_call_result(sim)["data"] == 0xFFFB + assert self.dut._fu_read.get_call_result(sim)["data"] == 0xFFFB + await sim.tick() + + self.dut._fu_write.call_init(sim, data=0x0FFF) + self.dut.write.call_init(sim, data=0xAAAA) + while self.dut._fu_write.get_call_result(sim) is None or self.dut.write.get_call_result(sim) is None: + await sim.tick() + assert data_const_to_dict(self.dut.read_comb.get_call_result(sim)) == {"data": 0x0FFF, "read": 1, "written": 1} + await sim.tick() + assert self.dut._fu_read.get_call_result(sim).data == 0xAAAA + await sim.tick() # single cycle - yield from self.dut._fu_write.call_init({"data": 0x0BBB}) - yield from self.dut._fu_write.call_do() - update_val = (yield from self.dut.read_comb.call_result())["data"] | 0xD000 - yield from self.dut.write.call_init({"data": update_val}) - yield from self.dut.write.call_do() - yield Tick() - assert (yield from self.dut._fu_read.call_result())["data"] == 0xDBBB + self.dut._fu_write.call_init(sim, data=0x0BBB) + while self.dut._fu_write.get_call_result(sim) is None: + await sim.tick() + update_val = self.dut.read_comb.get_call_result(sim).data | 0xD000 + self.dut.write.call_init(sim, data=update_val) + while self.dut.write.get_call_result(sim) is None: + await sim.tick() + await sim.tick() + assert self.dut._fu_read.get_call_result(sim).data == 0xDBBB def test_comb(self): gen_params = GenParams(test_core_config) @@ -387,4 +394,4 @@ def test_comb(self): self.dut = SimpleTestCircuit(CSRRegister(None, gen_params, ro_bits=0b1111, fu_write_priority=False, init=0xAB)) with self.run_simulation(self.dut) as sim: - sim.add_process(self.comb_process_test) + sim.add_testbench(self.comb_process_test) diff --git a/test/func_blocks/fu/common/test_rs.py b/test/func_blocks/fu/common/test_rs.py index 222041a2a..7d311dede 100644 --- a/test/func_blocks/fu/common/test_rs.py +++ b/test/func_blocks/fu/common/test_rs.py @@ -2,15 +2,14 @@ from collections import deque from parameterized import parameterized_class -from amaranth.sim import Settle, Tick - -from transactron.testing import TestCaseWithSimulator, get_outputs, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.func_blocks.fu.common.rs import RS, RSBase from coreblocks.func_blocks.fu.common.fifo_rs import FifoRS from coreblocks.params import * from coreblocks.params.configurations import test_core_config from coreblocks.arch import OpType +from transactron.testing.functions import data_const_to_dict def create_check_list(rs_entries_bits: int, insert_list: list[dict]) -> list[dict]: @@ -35,7 +34,7 @@ def create_data_list(gen_params: GenParams, count: int): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": k, "s2_val": k, @@ -75,35 +74,35 @@ def test_rs(self): self.finished = False with self.run_simulation(self.m) as sim: - sim.add_process(self.select_process) - sim.add_process(self.insert_process) - sim.add_process(self.update_process) - sim.add_process(self.take_process) + sim.add_testbench(self.select_process) + sim.add_testbench(self.insert_process) + sim.add_testbench(self.update_process) + sim.add_testbench(self.take_process) - def select_process(self): + async def select_process(self, sim: TestbenchContext): for k in range(len(self.data_list)): - rs_entry_id = (yield from self.m.select.call())["rs_entry_id"] + rs_entry_id = (await self.m.select.call(sim)).rs_entry_id self.select_queue.appendleft(rs_entry_id) self.rs_entries[rs_entry_id] = k - def insert_process(self): + async def insert_process(self, sim: TestbenchContext): for data in self.data_list: - yield Settle() # so that select_process can insert into the queue + await sim.delay(1e-9) # so that select_process can insert into the queue while not self.select_queue: - yield Tick() - yield Settle() + await sim.tick() + await sim.delay(1e-9) rs_entry_id = self.select_queue.pop() - yield from self.m.insert.call({"rs_entry_id": rs_entry_id, "rs_data": data}) + await self.m.insert.call(sim, rs_entry_id=rs_entry_id, rs_data=data) if data["rp_s1"]: self.regs_to_update.add(data["rp_s1"]) if data["rp_s2"]: self.regs_to_update.add(data["rp_s2"]) - def update_process(self): + async def update_process(self, sim: TestbenchContext): while not self.finished: - yield Settle() # so that insert_process can insert into the set + await sim.delay(1e-9) # so that insert_process can insert into the set if not self.regs_to_update: - yield Tick() + await sim.tick() continue reg_id = random.choice(list(self.regs_to_update)) self.regs_to_update.discard(reg_id) @@ -115,29 +114,26 @@ def update_process(self): if self.data_list[k]["rp_s2"] == reg_id: self.data_list[k]["rp_s2"] = 0 self.data_list[k]["s2_val"] = reg_val - yield from self.m.update.call(reg_id=reg_id, reg_val=reg_val) + await self.m.update.call(sim, reg_id=reg_id, reg_val=reg_val) - def take_process(self): + async def take_process(self, sim: TestbenchContext): taken: set[int] = set() - yield from self.m.get_ready_list[0].call_init() - yield Settle() + self.m.get_ready_list[0].call_init(sim) for k in range(len(self.data_list)): - yield Settle() - while not (yield from self.m.get_ready_list[0].done()): - yield Tick() - ready_list = (yield from self.m.get_ready_list[0].call_result())["ready_list"] + while not self.m.get_ready_list[0].get_done(sim): + await sim.tick() + ready_list = (self.m.get_ready_list[0].get_call_result(sim)).ready_list possible_ids = [i for i in range(2**self.rs_entries_bits) if ready_list & (1 << i)] - if not possible_ids: - yield Tick() - continue + while not possible_ids: + await sim.tick() rs_entry_id = random.choice(possible_ids) k = self.rs_entries[rs_entry_id] taken.add(k) test_data = dict(self.data_list[k]) del test_data["rp_s1"] del test_data["rp_s2"] - data = yield from self.m.take.call(rs_entry_id=rs_entry_id) - assert data == test_data + data = await self.m.take.call(sim, rs_entry_id=rs_entry_id) + assert data_const_to_dict(data) == test_data assert taken == set(range(len(self.data_list))) self.finished = True @@ -158,7 +154,7 @@ def test_insert(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -171,20 +167,18 @@ def test_insert(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for index, record in enumerate(self.insert_list): - assert (yield self.m._dut.data[index].rec_full) == 0 - yield from self.m.insert.call(record) - yield Settle() - assert (yield self.m._dut.data[index].rec_full) == 1 - yield Settle() + assert sim.get(self.m._dut.data[index].rec_full) == 0 + await self.m.insert.call(sim, record) + assert sim.get(self.m._dut.data[index].rec_full) == 1 # Check data integrity for expected, record in zip(self.check_list, self.m._dut.data): - assert expected == (yield from get_outputs(record)) + assert expected == data_const_to_dict(sim.get(record)) class TestRSMethodSelect(TestCaseWithSimulator): @@ -203,7 +197,7 @@ def test_select(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -216,38 +210,33 @@ def test_select(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # In the beginning the select method should be ready and id should be selectable for index, record in enumerate(self.insert_list): - assert (yield self.m._dut.select.ready) == 1 - assert (yield from self.m.select.call())["rs_entry_id"] == index - yield Settle() - assert (yield self.m._dut.data[index].rec_reserved) == 1 - yield from self.m.insert.call(record) - yield Settle() + assert sim.get(self.m._dut.select.ready) == 1 + assert (await self.m.select.call(sim)).rs_entry_id == index + assert sim.get(self.m._dut.data[index].rec_reserved) == 1 + await self.m.insert.call(sim, record) # Check if RS state is as expected for expected, record in zip(self.check_list, self.m._dut.data): - assert (yield record.rec_full) == expected["rec_full"] - assert (yield record.rec_reserved) == expected["rec_reserved"] + assert sim.get(record.rec_full) == expected["rec_full"] + assert sim.get(record.rec_reserved) == expected["rec_reserved"] # Reserve the last entry, then select ready should be false - assert (yield self.m._dut.select.ready) == 1 - assert (yield from self.m.select.call())["rs_entry_id"] == 3 - yield Settle() - assert (yield self.m._dut.select.ready) == 0 + assert sim.get(self.m._dut.select.ready) == 1 + assert (await self.m.select.call(sim)).rs_entry_id == 3 + assert sim.get(self.m._dut.select.ready) == 0 # After take, select ready should be true, with 0 index returned - yield from self.m.take.call(rs_entry_id=0) - yield Settle() - assert (yield self.m._dut.select.ready) == 1 - assert (yield from self.m.select.call())["rs_entry_id"] == 0 + await self.m.take.call(sim, rs_entry_id=0) + assert sim.get(self.m._dut.select.ready) == 1 + assert (await self.m.select.call(sim)).rs_entry_id == 0 # After reservation, select is false again - yield Settle() - assert (yield self.m._dut.select.ready) == 0 + assert sim.get(self.m._dut.select.ready) == 0 class TestRSMethodUpdate(TestCaseWithSimulator): @@ -266,7 +255,7 @@ def test_update(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -279,34 +268,31 @@ def test_update(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # Insert all reacords for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) # Check data integrity for expected, record in zip(self.check_list, self.m._dut.data): - assert expected == (yield from get_outputs(record)) + assert expected == data_const_to_dict(sim.get(record)) # Update second entry first SP, instruction should be not ready value_sp1 = 1010 - assert (yield self.m._dut.data_ready[1]) == 0 - yield from self.m.update.call(reg_id=2, reg_val=value_sp1) - yield Settle() - assert (yield self.m._dut.data[1].rs_data.rp_s1) == 0 - assert (yield self.m._dut.data[1].rs_data.s1_val) == value_sp1 - assert (yield self.m._dut.data_ready[1]) == 0 + assert sim.get(self.m._dut.data_ready[1]) == 0 + await self.m.update.call(sim, reg_id=2, reg_val=value_sp1) + assert sim.get(self.m._dut.data[1].rs_data.rp_s1) == 0 + assert sim.get(self.m._dut.data[1].rs_data.s1_val) == value_sp1 + assert sim.get(self.m._dut.data_ready[1]) == 0 # Update second entry second SP, instruction should be ready value_sp2 = 2020 - yield from self.m.update.call(reg_id=3, reg_val=value_sp2) - yield Settle() - assert (yield self.m._dut.data[1].rs_data.rp_s2) == 0 - assert (yield self.m._dut.data[1].rs_data.s2_val) == value_sp2 - assert (yield self.m._dut.data_ready[1]) == 1 + await self.m.update.call(sim, reg_id=3, reg_val=value_sp2) + assert sim.get(self.m._dut.data[1].rs_data.rp_s2) == 0 + assert sim.get(self.m._dut.data[1].rs_data.s2_val) == value_sp2 + assert sim.get(self.m._dut.data_ready[1]) == 1 # Insert new instruction to entries 0 and 1, check if update of multiple registers works reg_id = 4 @@ -319,7 +305,7 @@ def simulation_process(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": 0, "s2_val": 0, @@ -327,18 +313,16 @@ def simulation_process(self): } for index in range(2): - yield from self.m.insert.call(rs_entry_id=index, rs_data=data) - yield Settle() - assert (yield self.m._dut.data_ready[index]) == 0 + await self.m.insert.call(sim, rs_entry_id=index, rs_data=data) + assert sim.get(self.m._dut.data_ready[index]) == 0 - yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx) - yield Settle() + await self.m.update.call(sim, reg_id=reg_id, reg_val=value_spx) for index in range(2): - assert (yield self.m._dut.data[index].rs_data.rp_s1) == 0 - assert (yield self.m._dut.data[index].rs_data.rp_s2) == 0 - assert (yield self.m._dut.data[index].rs_data.s1_val) == value_spx - assert (yield self.m._dut.data[index].rs_data.s2_val) == value_spx - assert (yield self.m._dut.data_ready[index]) == 1 + assert sim.get(self.m._dut.data[index].rs_data.rp_s1) == 0 + assert sim.get(self.m._dut.data[index].rs_data.rp_s2) == 0 + assert sim.get(self.m._dut.data[index].rs_data.s1_val) == value_spx + assert sim.get(self.m._dut.data[index].rs_data.s2_val) == value_spx + assert sim.get(self.m._dut.data_ready[index]) == 1 class TestRSMethodTake(TestCaseWithSimulator): @@ -357,7 +341,7 @@ def test_take(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -370,37 +354,33 @@ def test_take(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) # Check data integrity for expected, record in zip(self.check_list, self.m._dut.data): - assert expected == (yield from get_outputs(record)) + assert expected == data_const_to_dict(sim.get(record)) # Take first instruction - assert (yield self.m._dut.get_ready_list[0].ready) == 1 - data = yield from self.m.take.call(rs_entry_id=0) + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=0)) for key in data: assert data[key] == self.check_list[0]["rs_data"][key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 0 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 0 # Update second instuction and take it reg_id = 2 value_spx = 1 - yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx) - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 1 - data = yield from self.m.take.call(rs_entry_id=1) + await self.m.update.call(sim, reg_id=reg_id, reg_val=value_spx) + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=1)) for key in data: assert data[key] == self.check_list[1]["rs_data"][key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 0 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 0 # Insert two new ready instructions and take them reg_id = 0 @@ -413,7 +393,7 @@ def simulation_process(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": 0, "s2_val": 0, @@ -422,22 +402,19 @@ def simulation_process(self): } for index in range(2): - yield from self.m.insert.call(rs_entry_id=index, rs_data=entry_data) - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 1 - assert (yield self.m._dut.data_ready[index]) == 1 + await self.m.insert.call(sim, rs_entry_id=index, rs_data=entry_data) + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 + assert sim.get(self.m._dut.data_ready[index]) == 1 - data = yield from self.m.take.call(rs_entry_id=0) + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=0)) for key in data: assert data[key] == entry_data[key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 1 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 - data = yield from self.m.take.call(rs_entry_id=1) + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=1)) for key in data: assert data[key] == entry_data[key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 0 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 0 class TestRSMethodGetReadyList(TestCaseWithSimulator): @@ -456,7 +433,7 @@ def test_get_ready_list(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -469,28 +446,25 @@ def test_get_ready_list(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) # Check ready vector integrity - ready_list = (yield from self.m.get_ready_list[0].call())["ready_list"] + ready_list = (await self.m.get_ready_list[0].call(sim)).ready_list assert ready_list == 0b0011 # Take first record and check ready vector integrity - yield from self.m.take.call(rs_entry_id=0) - yield Settle() - ready_list = (yield from self.m.get_ready_list[0].call())["ready_list"] + await self.m.take.call(sim, rs_entry_id=0) + ready_list = (await self.m.get_ready_list[0].call(sim)).ready_list assert ready_list == 0b0010 # Take second record and check ready vector integrity - yield from self.m.take.call(rs_entry_id=1) - yield Settle() - option_ready_list = yield from self.m.get_ready_list[0].call_try() + await self.m.take.call(sim, rs_entry_id=1) + option_ready_list = await self.m.get_ready_list[0].call_try(sim) assert option_ready_list is None @@ -500,7 +474,7 @@ def test_two_get_ready_lists(self): self.rs_entries = self.gen_params.max_rs_entries self.rs_entries_bits = self.gen_params.max_rs_entries_bits self.m = SimpleTestCircuit( - RS(self.gen_params, 2**self.rs_entries_bits, 0, [[OpType(1), OpType(2)], [OpType(3), OpType(4)]]) + RS(self.gen_params, 2**self.rs_entries_bits, 0, [[OpType(1), OpType(2)], [OpType(3), OpType(4)]]), ) self.insert_list = [ { @@ -513,7 +487,7 @@ def test_two_get_ready_lists(self): "exec_fn": { "op_type": OpType(id + 1), "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -525,29 +499,27 @@ def test_two_get_ready_lists(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) masks = [0b0011, 0b1100] for i in range(self.m._dut.rs_entries + 1): # Check ready vectors' integrity for j in range(2): - ready_list = yield from self.m.get_ready_list[j].call_try() + ready_list = await self.m.get_ready_list[j].call_try(sim) if masks[j]: - assert ready_list == {"ready_list": masks[j]} + assert ready_list.ready_list == masks[j] else: assert ready_list is None # Take a record if i == self.m._dut.rs_entries: break - yield from self.m.take.call(rs_entry_id=i) - yield Settle() + await self.m.take.call(sim, rs_entry_id=i) masks = [mask & ~(1 << i) for mask in masks] diff --git a/test/func_blocks/fu/fpu/test_fpu_error.py b/test/func_blocks/fu/fpu/test_fpu_error.py index 6938bfd17..fb4310131 100644 --- a/test/func_blocks/fu/fpu/test_fpu_error.py +++ b/test/func_blocks/fu/fpu/test_fpu_error.py @@ -45,7 +45,7 @@ def __init__(self, params: FPUParams): def test_special_cases(self, params: FPUParams, help_values: HelpValues): fpue = TestFPUError.FPUErrorModule(params) - def other_cases_test(): + async def other_cases_test(sim: TestbenchContext): test_cases = [ # No errors { @@ -53,7 +53,7 @@ def other_cases_test(): "sig": help_values.not_max_norm_even_sig, "exp": help_values.not_max_norm_exp, "inexact": 0, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 0, @@ -64,7 +64,7 @@ def other_cases_test(): "sig": help_values.not_max_norm_even_sig, "exp": help_values.not_max_norm_exp, "inexact": 1, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 0, @@ -75,7 +75,7 @@ def other_cases_test(): "sig": help_values.sub_norm_sig, "exp": 0, "inexact": 1, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 0, @@ -86,7 +86,7 @@ def other_cases_test(): "sig": help_values.qnan, "exp": help_values.max_exp, "inexact": 1, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 1, "division_by_zero": 0, "input_inf": 0, @@ -97,7 +97,7 @@ def other_cases_test(): "sig": 0, "exp": help_values.max_exp, "inexact": 1, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 1, "input_inf": 0, @@ -108,7 +108,7 @@ def other_cases_test(): "sig": 0, "exp": help_values.max_exp, "inexact": 0, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 0, @@ -119,7 +119,7 @@ def other_cases_test(): "sig": help_values.sub_norm_sig, "exp": 0, "inexact": 0, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 0, @@ -130,7 +130,7 @@ def other_cases_test(): "sig": help_values.qnan, "exp": help_values.max_exp, "inexact": 1, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 0, @@ -141,7 +141,7 @@ def other_cases_test(): "sig": 0, "exp": help_values.max_exp, "inexact": 1, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 1, @@ -152,7 +152,7 @@ def other_cases_test(): "sig": help_values.min_norm_sig, "exp": 0, "inexact": 1, - "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY.value, "invalid_operation": 0, "division_by_zero": 0, "input_inf": 0, @@ -188,17 +188,17 @@ def other_cases_test(): ] for i in range(len(test_cases)): - resp = yield from fpue.error_checking_request_adapter.call(test_cases[i]) - assert resp["sign"] == expected_results[i]["sign"] - assert resp["exp"] == expected_results[i]["exp"] - assert resp["sig"] == expected_results[i]["sig"] - assert resp["errors"] == expected_results[i]["errors"] + resp = await fpue.error_checking_request_adapter.call(sim, test_cases[i]) + assert resp.sign == expected_results[i]["sign"] + assert resp.exp == expected_results[i]["exp"] + assert resp.sig == expected_results[i]["sig"] + assert resp.errors == expected_results[i]["errors"] - def test_process(): - yield from other_cases_test() + async def test_process(sim: TestbenchContext): + await other_cases_test(sim) with self.run_simulation(fpue) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) @parameterized.expand( [ @@ -261,14 +261,15 @@ def test_rounding( ): fpue = TestFPUError.FPUErrorModule(params) - def one_rounding_mode_test(): + async def one_rounding_mode_test(sim: TestbenchContext): + rm_int = rm.value # TODO: workaround for amaranth bug test_cases = [ # overflow detection { "sign": 0, "sig": 0, "exp": help_values.max_exp, - "rounding_mode": rm, + "rounding_mode": rm_int, "inexact": 0, "invalid_operation": 0, "division_by_zero": 0, @@ -278,7 +279,7 @@ def one_rounding_mode_test(): "sign": 1, "sig": 0, "exp": help_values.max_exp, - "rounding_mode": rm, + "rounding_mode": rm_int, "inexact": 0, "invalid_operation": 0, "division_by_zero": 0, @@ -292,14 +293,14 @@ def one_rounding_mode_test(): ] for i in range(len(test_cases)): - resp = yield from fpue.error_checking_request_adapter.call(test_cases[i]) + resp = await fpue.error_checking_request_adapter.call(sim, test_cases[i]) assert resp["sign"] == expected_results[i]["sign"] assert resp["exp"] == expected_results[i]["exp"] assert resp["sig"] == expected_results[i]["sig"] assert resp["errors"] == expected_results[i]["errors"] - def test_process(): - yield from one_rounding_mode_test() + async def test_process(sim: TestbenchContext): + await one_rounding_mode_test(sim) with self.run_simulation(fpue) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) diff --git a/test/func_blocks/fu/fpu/test_fpu_rounding.py b/test/func_blocks/fu/fpu/test_fpu_rounding.py index 0b1e40865..86a132f3d 100644 --- a/test/func_blocks/fu/fpu/test_fpu_rounding.py +++ b/test/func_blocks/fu/fpu/test_fpu_rounding.py @@ -95,7 +95,8 @@ def test_rounding( ): fpurt = TestFPURounding.FPURoundingModule(params) - def one_rounding_mode_test(): + async def one_rounding_mode_test(sim: TestbenchContext): + rm_int = rm.value # TODO: workaround for Amaranth bug test_cases = [ # carry after increment { @@ -104,7 +105,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 1, "sticky_bit": 1, - "rounding_mode": rm, + "rounding_mode": rm_int, }, # no overflow 00 { @@ -113,7 +114,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 0, "sticky_bit": 0, - "rounding_mode": rm, + "rounding_mode": rm_int, }, { "sign": 1, @@ -121,7 +122,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 0, "sticky_bit": 0, - "rounding_mode": rm, + "rounding_mode": rm_int, }, # no overflow 10 { @@ -130,7 +131,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 1, "sticky_bit": 0, - "rounding_mode": rm, + "rounding_mode": rm_int, }, { "sign": 1, @@ -138,7 +139,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 1, "sticky_bit": 0, - "rounding_mode": rm, + "rounding_mode": rm_int, }, # no overflow 01 { @@ -147,7 +148,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 0, "sticky_bit": 1, - "rounding_mode": rm, + "rounding_mode": rm_int, }, { "sign": 1, @@ -155,7 +156,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 0, "sticky_bit": 1, - "rounding_mode": rm, + "rounding_mode": rm_int, }, # no overflow 11 { @@ -164,7 +165,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 1, "sticky_bit": 1, - "rounding_mode": rm, + "rounding_mode": rm_int, }, { "sign": 1, @@ -172,7 +173,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 1, "sticky_bit": 1, - "rounding_mode": rm, + "rounding_mode": rm_int, }, # Round to nearest tie to even { @@ -181,7 +182,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 1, "sticky_bit": 0, - "rounding_mode": rm, + "rounding_mode": rm_int, }, { "sign": 0, @@ -189,7 +190,7 @@ def one_rounding_mode_test(): "exp": help_values.not_max_norm_exp, "round_bit": 1, "sticky_bit": 0, - "rounding_mode": rm, + "rounding_mode": rm_int, }, ] expected_results = [ @@ -264,13 +265,13 @@ def one_rounding_mode_test(): for i in range(num_of_test_cases): - resp = yield from fpurt.rounding_request_adapter.call(test_cases[i]) - assert resp["exp"] == expected_results[i]["exp"] - assert resp["sig"] == expected_results[i]["sig"] - assert resp["inexact"] == expected_results[i]["inexact"] + resp = await fpurt.rounding_request_adapter.call(sim, test_cases[i]) + assert resp.exp == expected_results[i]["exp"] + assert resp.sig == expected_results[i]["sig"] + assert resp.inexact == expected_results[i]["inexact"] - def test_process(): - yield from one_rounding_mode_test() + async def test_process(sim: TestbenchContext): + await one_rounding_mode_test(sim) with self.run_simulation(fpurt) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) diff --git a/test/func_blocks/fu/functional_common.py b/test/func_blocks/fu/functional_common.py index 088c4337d..85d34ab1d 100644 --- a/test/func_blocks/fu/functional_common.py +++ b/test/func_blocks/fu/functional_common.py @@ -6,11 +6,11 @@ from typing import Generic, TypeVar from amaranth import Elaboratable, Signal -from amaranth.sim import Passive, Tick from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config from coreblocks.priv.csr.csr_instances import GenericCSRRegisters +from transactron.testing.functions import data_const_to_dict from transactron.utils.dependencies import DependencyContext from coreblocks.params.fu_params import FunctionalComponentParams from coreblocks.arch import Funct3, Funct7 @@ -18,7 +18,14 @@ from coreblocks.interface.layouts import ExceptionRegisterLayouts from coreblocks.arch.optypes import OpType from transactron.lib import Adapter -from transactron.testing import RecordIntDict, RecordIntDictRet, TestbenchIO, TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import ( + RecordIntDict, + TestbenchIO, + TestCaseWithSimulator, + SimpleTestCircuit, + ProcessContext, + TestbenchContext, +) from transactron.utils import ModuleConnector @@ -111,8 +118,8 @@ def setup(self, fixture_initialize_testing_env): random.seed(self.seed) self.requests = deque[RecordIntDict]() - self.responses = deque[RecordIntDictRet]() - self.exceptions = deque[RecordIntDictRet]() + self.responses = deque[RecordIntDict]() + self.exceptions = deque[RecordIntDict]() max_int = 2**self.gen_params.isa.xlen - 1 functions = list(self.ops.keys()) @@ -158,37 +165,38 @@ def setup(self, fixture_initialize_testing_env): self.responses.append({"rob_id": rob_id, "rp_dst": rp_dst, "exception": int(cause is not None)} | results) - def consumer(self): + async def consumer(self, sim: TestbenchContext): while self.responses: expected = self.responses.pop() - result = yield from self.m.accept.call() - assert expected == result - yield from self.random_wait(self.max_wait) + result = await self.m.accept.call(sim) + assert expected == data_const_to_dict(result) + await self.random_wait(sim, self.max_wait) - def producer(self): + async def producer(self, sim: TestbenchContext): while self.requests: req = self.requests.pop() - yield from self.m.issue.call(req) - yield from self.random_wait(self.max_wait) - - def exception_consumer(self): - while self.exceptions: - expected = self.exceptions.pop() - result = yield from self.report_mock.call() - assert expected == result - yield from self.random_wait(self.max_wait) + await self.m.issue.call(sim, req) + await self.random_wait(sim, self.max_wait) + + async def exception_consumer(self, sim: TestbenchContext): + # This is a background testbench so that extra calls can be detected reliably + with sim.critical(): + while self.exceptions: + expected = self.exceptions.pop() + result = await self.report_mock.call(sim) + assert expected == data_const_to_dict(result) + await self.random_wait(sim, self.max_wait) # keep partialy dependent tests from hanging up and detect extra calls - yield Passive() - result = yield from self.report_mock.call() + result = await self.report_mock.call(sim) assert not True, "unexpected report call" - def pipeline_verifier(self): - yield Passive() - while True: - assert (yield self.m.issue.adapter.iface.ready) - assert (yield self.m.issue.adapter.en) == (yield self.m.issue.adapter.done) - yield Tick() + async def pipeline_verifier(self, sim: ProcessContext): + async for *_, ready, en, done in sim.tick().sample( + self.m.issue.adapter.iface.ready, self.m.issue.adapter.en, self.m.issue.adapter.done + ): + assert ready + assert en == done def run_standard_fu_test(self, pipeline_test=False): if pipeline_test: @@ -197,8 +205,8 @@ def run_standard_fu_test(self, pipeline_test=False): self.max_wait = 10 with self.run_simulation(self.circ) as sim: - sim.add_process(self.producer) - sim.add_process(self.consumer) - sim.add_process(self.exception_consumer) + sim.add_testbench(self.producer) + sim.add_testbench(self.consumer) + sim.add_testbench(self.exception_consumer, background=True) if pipeline_test: sim.add_process(self.pipeline_verifier) diff --git a/test/func_blocks/fu/test_fu_decoder.py b/test/func_blocks/fu/test_fu_decoder.py index 9c6601f3c..cedaf93b1 100644 --- a/test/func_blocks/fu/test_fu_decoder.py +++ b/test/func_blocks/fu/test_fu_decoder.py @@ -1,9 +1,7 @@ import random -from typing import Sequence, Generator -from amaranth import * -from amaranth.sim import * +from collections.abc import Sequence -from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator +from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchContext from coreblocks.func_blocks.fu.common.fu_decoder import DecoderManager, Decoder from coreblocks.arch import OpType, Funct3, Funct7 @@ -31,21 +29,19 @@ def expected_results(self, instructions: Sequence[tuple], op_type_dependent: boo return acc - def handle_signals(self, decoder: Decoder, exec_fn: dict[str, int]) -> Generator: - yield decoder.exec_fn.op_type.eq(exec_fn["op_type"]) - yield decoder.exec_fn.funct3.eq(exec_fn["funct3"]) - yield decoder.exec_fn.funct7.eq(exec_fn["funct7"]) + async def handle_signals(self, sim: TestbenchContext, decoder: Decoder, exec_fn: dict[str, int]): + sim.set(decoder.exec_fn.op_type, exec_fn["op_type"]) + sim.set(decoder.exec_fn.funct3, exec_fn["funct3"]) + sim.set(decoder.exec_fn.funct7, exec_fn["funct7"]) - yield Settle() - - return (yield decoder.decode_fn) + return sim.get(decoder.decode_fn) def run_test_case(self, decoder_manager: DecoderManager, test_inputs: Sequence[tuple]) -> None: instructions = decoder_manager.get_instructions() decoder = decoder_manager.get_decoder(self.gen_params) op_type_dependent = len(decoder_manager.get_op_types()) != 1 - def process(): + async def process(sim: TestbenchContext): for test_input in test_inputs: exec_fn = { "op_type": test_input[1], @@ -53,7 +49,7 @@ def process(): "funct7": test_input[3] if len(test_input) >= 4 else 0, } - returned = yield from self.handle_signals(decoder, exec_fn) + returned = await self.handle_signals(sim, decoder, exec_fn) expected = self.expected_results(instructions, op_type_dependent, exec_fn) assert returned == expected @@ -61,7 +57,7 @@ def process(): test_circuit = SimpleTestCircuit(decoder) with self.run_simulation(test_circuit) as sim: - sim.add_process(process) + sim.add_testbench(process) def generate_random_instructions(self) -> Sequence[tuple]: random.seed(42) diff --git a/test/func_blocks/fu/test_pipelined_mul_unit.py b/test/func_blocks/fu/test_pipelined_mul_unit.py index 1c955c6b4..20b46ff14 100644 --- a/test/func_blocks/fu/test_pipelined_mul_unit.py +++ b/test/func_blocks/fu/test_pipelined_mul_unit.py @@ -2,15 +2,15 @@ import math from collections import deque -from amaranth.sim import Settle from parameterized import parameterized_class from coreblocks.func_blocks.fu.unsigned_multiplication.pipelined import PipelinedUnsignedMul -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config +from transactron.testing.functions import data_const_to_dict @parameterized_class( @@ -57,14 +57,14 @@ def setup_method(self): ) def test_pipeline(self): - def consumer(): + async def consumer(sim: TestbenchContext): time = 0 while self.responses: - res = yield from self.m.accept.call_try() + res = await self.m.accept.call_try(sim) time += 1 if res is not None: expected = self.responses.pop() - assert expected == res + assert expected == data_const_to_dict(res) assert ( time @@ -73,12 +73,11 @@ def consumer(): + 2 ) - def producer(): + async def producer(sim: TestbenchContext): while self.requests: req = self.requests.pop() - yield Settle() - yield from self.m.issue.call(req) + await self.m.issue.call(sim, req) with self.run_simulation(self.m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) diff --git a/test/func_blocks/fu/test_unsigned_mul_unit.py b/test/func_blocks/fu/test_unsigned_mul_unit.py index 06321672c..bb522c73c 100644 --- a/test/func_blocks/fu/test_unsigned_mul_unit.py +++ b/test/func_blocks/fu/test_unsigned_mul_unit.py @@ -1,8 +1,6 @@ import random from collections import deque -from typing import Type -from amaranth.sim import Settle from parameterized import parameterized_class from coreblocks.func_blocks.fu.unsigned_multiplication.common import MulBaseUnsigned @@ -11,10 +9,11 @@ from coreblocks.func_blocks.fu.unsigned_multiplication.shift import ShiftUnsignedMul from coreblocks.func_blocks.fu.unsigned_multiplication.pipelined import PipelinedUnsignedMul -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config +from transactron.testing.functions import data_const_to_dict @parameterized_class( @@ -39,7 +38,7 @@ ], ) class TestUnsignedMultiplicationUnit(TestCaseWithSimulator): - mul_unit: Type[MulBaseUnsigned] + mul_unit: type[MulBaseUnsigned] def setup_method(self): self.gen_params = GenParams(test_core_config) @@ -68,20 +67,19 @@ def setup_method(self): ) def test_pipeline(self): - def consumer(): + async def consumer(sim: TestbenchContext): while self.responses: expected = self.responses.pop() - result = yield from self.m.accept.call() - assert expected == result - yield from self.random_wait(self.waiting_time) + result = await self.m.accept.call(sim) + assert expected == data_const_to_dict(result) + await self.random_wait(sim, self.waiting_time) - def producer(): + async def producer(sim: TestbenchContext): while self.requests: req = self.requests.pop() - yield Settle() - yield from self.m.issue.call(req) - yield from self.random_wait(self.waiting_time) + await self.m.issue.call(sim, req) + await self.random_wait(sim, self.waiting_time) with self.run_simulation(self.m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) diff --git a/test/func_blocks/lsu/test_dummylsu.py b/test/func_blocks/lsu/test_dummylsu.py index b0e1a702f..3a13149dc 100644 --- a/test/func_blocks/lsu/test_dummylsu.py +++ b/test/func_blocks/lsu/test_dummylsu.py @@ -1,9 +1,8 @@ import random from collections import deque -from amaranth.sim import Settle, Passive, Tick - from transactron.lib import Adapter +from transactron.testing.method_mock import MethodMock from transactron.utils import int_to_signed, signed_to_int from coreblocks.params import GenParams from coreblocks.func_blocks.fu.lsu.dummyLsu import LSUDummy @@ -13,7 +12,7 @@ from transactron.utils.dependencies import DependencyContext from coreblocks.interface.layouts import ExceptionRegisterLayouts, RetirementLayouts from coreblocks.peripherals.wishbone import * -from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock +from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock, TestbenchContext from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter from test.peripherals.test_wishbone import WishboneInterfaceWrapper @@ -199,11 +198,8 @@ def setup_method(self) -> None: self.generate_instr(2**7, 2**7) self.max_wait = 10 - def wishbone_slave(self): - yield Passive() - - while True: - yield from self.test_module.io_in.slave_wait() + async def wishbone_slave(self, sim: TestbenchContext): + while self.mem_data_queue: generated_data = self.mem_data_queue.pop() if generated_data["misaligned"]: @@ -211,8 +207,8 @@ def wishbone_slave(self): mask = generated_data["mask"] sign = generated_data["sign"] - yield from self.test_module.io_in.slave_verify(generated_data["addr"], 0, 0, mask) - yield from self.random_wait(self.max_wait) + await self.test_module.io_in.slave_wait_and_verify(sim, generated_data["addr"], 0, 0, mask) + await self.random_wait(sim, self.max_wait) resp_data = int((generated_data["rnd_bytes"][:4]).hex(), 16) data_shift = (mask & -mask).bit_length() - 1 @@ -225,21 +221,20 @@ def wishbone_slave(self): data = int_to_signed(signed_to_int(data, size), 32) if not generated_data["err"]: self.returned_data.appendleft(data) - yield from self.test_module.io_in.slave_respond(resp_data, err=generated_data["err"]) - yield Settle() + await self.test_module.io_in.slave_respond(sim, resp_data, err=generated_data["err"]) - def inserter(self): + async def inserter(self, sim: TestbenchContext): for i in range(self.tests_number): req = self.instr_queue.pop() while req["rob_id"] not in self.free_rob_id: - yield Tick() + await sim.tick() self.free_rob_id.remove(req["rob_id"]) - yield from self.test_module.issue.call(req) - yield from self.random_wait(self.max_wait) + await self.test_module.issue.call(sim, req) + await self.random_wait(sim, self.max_wait) - def consumer(self): + async def consumer(self, sim: TestbenchContext): for i in range(self.tests_number): - v = yield from self.test_module.accept.call() + v = await self.test_module.accept.call(sim) rob_id = v["rob_id"] assert rob_id not in self.free_rob_id self.free_rob_id.add(rob_id) @@ -250,12 +245,14 @@ def consumer(self): assert v["result"] == self.returned_data.pop() assert v["exception"] == exc["err"] - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) def test(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert arg == self.exception_queue.pop() + @MethodMock.effect + def eff(): + assert arg == self.exception_queue.pop() @def_method_mock(lambda: self.test_module.precommit, validate_arguments=lambda rob_id: True) def precommiter(rob_id): @@ -266,9 +263,9 @@ def core_state_process(): return {"flushing": 0} with self.run_simulation(self.test_module) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.inserter) - sim.add_process(self.consumer) + sim.add_testbench(self.wishbone_slave, background=True) + sim.add_testbench(self.inserter) + sim.add_testbench(self.consumer) class TestDummyLSULoadsCycles(TestCaseWithSimulator): @@ -301,33 +298,33 @@ def setup_method(self) -> None: self.gen_params = GenParams(test_core_config.replace(phys_regs_bits=3, rob_entries_bits=3)) self.test_module = DummyLSUTestCircuit(self.gen_params) - def one_instr_test(self): + async def one_instr_test(self, sim: TestbenchContext): instr, wish_data = self.generate_instr(2**7, 2**7) - yield from self.test_module.issue.call(instr) - yield from self.test_module.io_in.slave_wait() + await self.test_module.issue.call(sim, instr) mask = wish_data["mask"] - yield from self.test_module.io_in.slave_verify(wish_data["addr"], 0, 0, mask) + await self.test_module.io_in.slave_wait_and_verify(sim, wish_data["addr"], 0, 0, mask) data = wish_data["rnd_bytes"][:4] data = int(data.hex(), 16) - yield from self.test_module.io_in.slave_respond(data) - yield Settle() + await self.test_module.io_in.slave_respond(sim, data) - v = yield from self.test_module.accept.call() + v = await self.test_module.accept.call(sim) assert v["result"] == data def test(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert False + @MethodMock.effect + def eff(): + assert False @def_method_mock(lambda: self.test_module.precommit, validate_arguments=lambda rob_id: True) def precommiter(rob_id): return {"side_fx": 1} with self.run_simulation(self.test_module) as sim: - sim.add_process(self.one_instr_test) + sim.add_testbench(self.one_instr_test) class TestDummyLSUStores(TestCaseWithSimulator): @@ -381,9 +378,8 @@ def setup_method(self) -> None: self.generate_instr(2**7, 2**7) self.max_wait = 8 - def wishbone_slave(self): + async def wishbone_slave(self, sim: TestbenchContext): for i in range(self.tests_number): - yield from self.test_module.io_in.slave_wait() generated_data = self.mem_data_queue.pop() mask = generated_data["mask"] @@ -395,27 +391,26 @@ def wishbone_slave(self): data = (int(generated_data["data"][-2:].hex(), 16) & 0xFFFF) << h_dict[mask] else: data = int(generated_data["data"][-4:].hex(), 16) - yield from self.test_module.io_in.slave_verify(generated_data["addr"], data, 1, mask) - yield from self.random_wait(self.max_wait) + await self.test_module.io_in.slave_wait_and_verify(sim, generated_data["addr"], data, 1, mask) + await self.random_wait(sim, self.max_wait) - yield from self.test_module.io_in.slave_respond(0) - yield Settle() + await self.test_module.io_in.slave_respond(sim, 0) - def inserter(self): + async def inserter(self, sim: TestbenchContext): for i in range(self.tests_number): req = self.instr_queue.pop() self.get_result_data.appendleft(req["rob_id"]) - yield from self.test_module.issue.call(req) + await self.test_module.issue.call(sim, req) self.precommit_data.appendleft(req["rob_id"]) - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) - def get_resulter(self): + async def get_resulter(self, sim: TestbenchContext): for i in range(self.tests_number): - v = yield from self.test_module.accept.call() + v = await self.test_module.accept.call(sim) rob_id = self.get_result_data.pop() assert v["rob_id"] == rob_id assert v["rp_dst"] == 0 - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) self.precommit_data.pop() # retire def precommit_validate(self, rob_id): @@ -428,37 +423,37 @@ def precommiter(self, rob_id): def test(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert False + @MethodMock.effect + def eff(): + assert False with self.run_simulation(self.test_module) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.inserter) - sim.add_process(self.get_resulter) - sim.add_process(self.precommiter) + sim.add_testbench(self.wishbone_slave) + sim.add_testbench(self.inserter) + sim.add_testbench(self.get_resulter) class TestDummyLSUFence(TestCaseWithSimulator): def get_instr(self, exec_fn): return {"rp_dst": 1, "rob_id": 1, "exec_fn": exec_fn, "s1_val": 4, "s2_val": 1, "imm": 8, "pc": 0} - def push_one_instr(self, instr): - yield from self.test_module.issue.call(instr) + async def push_one_instr(self, sim: TestbenchContext, instr): + await self.test_module.issue.call(sim, instr) if instr["exec_fn"]["op_type"] == OpType.LOAD: - yield from self.test_module.io_in.slave_wait() - yield from self.test_module.io_in.slave_respond(1) - yield Settle() - v = yield from self.test_module.accept.call() + await self.test_module.io_in.slave_wait(sim) + await self.test_module.io_in.slave_respond(sim, 1) + v = await self.test_module.accept.call(sim) if instr["exec_fn"]["op_type"] == OpType.LOAD: - assert v["result"] == 1 + assert v.result == 1 - def process(self): + async def process(self, sim: TestbenchContext): # just tests if FENCE doens't hang up the LSU load_fn = {"op_type": OpType.LOAD, "funct3": Funct3.W, "funct7": 0} fence_fn = {"op_type": OpType.FENCE, "funct3": 0, "funct7": 0} - yield from self.push_one_instr(self.get_instr(load_fn)) - yield from self.push_one_instr(self.get_instr(fence_fn)) - yield from self.push_one_instr(self.get_instr(load_fn)) + await self.push_one_instr(sim, self.get_instr(load_fn)) + await self.push_one_instr(sim, self.get_instr(fence_fn)) + await self.push_one_instr(sim, self.get_instr(load_fn)) def test_fence(self): self.gen_params = GenParams(test_core_config.replace(phys_regs_bits=3, rob_entries_bits=3)) @@ -466,11 +461,13 @@ def test_fence(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert False + @MethodMock.effect + def eff(): + assert False @def_method_mock(lambda: self.test_module.precommit, validate_arguments=lambda rob_id: True) def precommiter(rob_id): return {"side_fx": 1} with self.run_simulation(self.test_module) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) diff --git a/test/func_blocks/lsu/test_pma.py b/test/func_blocks/lsu/test_pma.py index 7f1addfe6..16e8aec4b 100644 --- a/test/func_blocks/lsu/test_pma.py +++ b/test/func_blocks/lsu/test_pma.py @@ -1,4 +1,4 @@ -from amaranth.sim import Settle +import random from coreblocks.func_blocks.fu.lsu.pma import PMAChecker, PMARegion from transactron.lib import Adapter @@ -7,25 +7,25 @@ from coreblocks.params.configurations import test_core_config from coreblocks.arch import * from coreblocks.interface.keys import CoreStateKey, ExceptionReportKey, InstructionPrecommitKey +from transactron.testing.method_mock import MethodMock from transactron.utils.dependencies import DependencyContext from coreblocks.interface.layouts import ExceptionRegisterLayouts, RetirementLayouts from coreblocks.peripherals.wishbone import * -from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock +from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock, TestbenchContext from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter from test.peripherals.test_wishbone import WishboneInterfaceWrapper class TestPMADirect(TestCaseWithSimulator): - def verify_region(self, region: PMARegion): + async def verify_region(self, sim: TestbenchContext, region: PMARegion): for i in range(region.start, region.end + 1): - yield self.test_module.addr.eq(i) - yield Settle() - mmio = yield self.test_module.result["mmio"] + sim.set(self.test_module.addr, i) + mmio = sim.get(self.test_module.result.mmio) assert mmio == region.mmio - def process(self): + async def process(self, sim: TestbenchContext): for r in self.pma_regions: - yield from self.verify_region(r) + await self.verify_region(sim, r) def test_pma_direct(self): self.pma_regions = [ @@ -40,7 +40,7 @@ def test_pma_direct(self): self.test_module = PMAChecker(self.gen_params) with self.run_simulation(self.test_module) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) class PMAIndirectTestCircuit(Elaboratable): @@ -100,29 +100,24 @@ def get_instr(self, addr): "pc": 0, } - def verify_region(self, region: PMARegion): + async def verify_region(self, sim: TestbenchContext, region: PMARegion): for addr in range(region.start, region.end + 1): instr = self.get_instr(addr) - yield from self.test_module.issue.call(instr) + await self.test_module.issue.call(sim, instr) if region.mmio is True: wb = self.test_module.io_in.wb for i in range(100): # 100 cycles is more than enough - wb_requested = (yield wb.stb) and (yield wb.cyc) + wb_requested = sim.get(wb.stb) and sim.get(wb.cyc) assert not wb_requested - yield from self.test_module.precommit.method_handle( - function=lambda rob_id: {"side_fx": 1}, validate_arguments=lambda rob_id: rob_id == 1 - ) + await self.test_module.io_in.slave_wait(sim) + await self.test_module.io_in.slave_respond(sim, (addr << (addr % 4) * 8)) + v = await self.test_module.accept.call(sim) + assert v.result == addr - yield from self.test_module.io_in.slave_wait() - yield from self.test_module.io_in.slave_respond((addr << (addr % 4) * 8)) - yield Settle() - v = yield from self.test_module.accept.call() - assert v["result"] == addr - - def process(self): + async def process(self, sim: TestbenchContext): for region in self.pma_regions: - yield from self.verify_region(region) + await self.verify_region(sim, region) def test_pma_indirect(self): self.pma_regions = [ @@ -135,11 +130,21 @@ def test_pma_indirect(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert False + @MethodMock.effect + def eff(): + assert False + + @def_method_mock( + lambda: self.test_module.precommit, + validate_arguments=lambda rob_id: rob_id == 1, + enable=lambda: random.random() < 0.5, + ) + def precommiter(rob_id): + return {"side_fx": 1} @def_method_mock(lambda: self.test_module.core_state) def core_state_process(): return {"flushing": 0} with self.run_simulation(self.test_module) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) diff --git a/test/peripherals/test_axi_lite.py b/test/peripherals/test_axi_lite.py index 514e85ea9..d1156271b 100644 --- a/test/peripherals/test_axi_lite.py +++ b/test/peripherals/test_axi_lite.py @@ -9,66 +9,77 @@ class AXILiteInterfaceWrapper: def __init__(self, axi_lite_master: AXILiteInterface): self.axi_lite = axi_lite_master - def slave_ra_ready(self, rdy=1): - yield self.axi_lite.read_address.rdy.eq(rdy) - - def slave_ra_wait(self): - while not (yield self.axi_lite.read_address.valid): - yield Tick() - - def slave_ra_verify(self, exp_addr, prot): - assert (yield self.axi_lite.read_address.valid) - assert (yield self.axi_lite.read_address.addr) == exp_addr - assert (yield self.axi_lite.read_address.prot) == prot - - def slave_rd_wait(self): - while not (yield self.axi_lite.read_data.rdy): - yield Tick() - - def slave_rd_respond(self, data, resp=0): - assert (yield self.axi_lite.read_data.rdy) - yield self.axi_lite.read_data.data.eq(data) - yield self.axi_lite.read_data.resp.eq(resp) - yield self.axi_lite.read_data.valid.eq(1) - yield Tick() - yield self.axi_lite.read_data.valid.eq(0) - - def slave_wa_ready(self, rdy=1): - yield self.axi_lite.write_address.rdy.eq(rdy) - - def slave_wa_wait(self): - while not (yield self.axi_lite.write_address.valid): - yield Tick() - - def slave_wa_verify(self, exp_addr, prot): - assert (yield self.axi_lite.write_address.valid) - assert (yield self.axi_lite.write_address.addr) == exp_addr - assert (yield self.axi_lite.write_address.prot) == prot - - def slave_wd_ready(self, rdy=1): - yield self.axi_lite.write_data.rdy.eq(rdy) - - def slave_wd_wait(self): - while not (yield self.axi_lite.write_data.valid): - yield Tick() - - def slave_wd_verify(self, exp_data, strb): - assert (yield self.axi_lite.write_data.valid) - assert (yield self.axi_lite.write_data.data) == exp_data - assert (yield self.axi_lite.write_data.strb) == strb - - def slave_wr_wait(self): - while not (yield self.axi_lite.write_response.rdy): - yield Tick() - - def slave_wr_respond(self, resp=0): - assert (yield self.axi_lite.write_response.rdy) - yield self.axi_lite.write_response.resp.eq(resp) - yield self.axi_lite.write_response.valid.eq(1) - yield Tick() - yield self.axi_lite.write_response.valid.eq(0) - - + def slave_ra_ready(self, sim: TestbenchContext, rdy=1): + sim.set(self.axi_lite.read_address.rdy, rdy) + + def slave_ra_get(self, sim: TestbenchContext): + ra = self.axi_lite.read_address + assert sim.get(ra.valid) + return sim.get(ra.addr), sim.get(ra.prot) + + def slave_ra_get_and_verify(self, sim: TestbenchContext, exp_addr: int, exp_prot: int): + addr, prot = self.slave_ra_get(sim) + assert addr == exp_addr + assert prot == exp_prot + + async def slave_rd_wait(self, sim: TestbenchContext): + rd = self.axi_lite.read_data + while not sim.get(rd.rdy): + await sim.tick() + + def slave_rd_get(self, sim: TestbenchContext): + rd = self.axi_lite.read_data + assert sim.get(rd.rdy) + + async def slave_rd_respond(self, sim: TestbenchContext, data, resp=0): + assert sim.get(self.axi_lite.read_data.rdy) + sim.set(self.axi_lite.read_data.data, data) + sim.set(self.axi_lite.read_data.resp, resp) + sim.set(self.axi_lite.read_data.valid, 1) + await sim.tick() + sim.set(self.axi_lite.read_data.valid, 0) + + def slave_wa_ready(self, sim: TestbenchContext, rdy=1): + sim.set(self.axi_lite.write_address.rdy, rdy) + + def slave_wa_get(self, sim: TestbenchContext): + wa = self.axi_lite.write_address + assert sim.get(wa.valid) + return sim.get(wa.addr), sim.get(wa.prot) + + def slave_wa_get_and_verify(self, sim: TestbenchContext, exp_addr, exp_prot): + addr, prot = self.slave_wa_get(sim) + assert addr == exp_addr + assert prot == exp_prot + + def slave_wd_ready(self, sim: TestbenchContext, rdy=1): + sim.set(self.axi_lite.write_data.rdy, rdy) + + def slave_wd_get(self, sim: TestbenchContext): + wd = self.axi_lite.write_data + assert sim.get(wd.valid) + return sim.get(wd.data), sim.get(wd.strb) + + def slave_wd_get_and_verify(self, sim: TestbenchContext, exp_data, exp_strb): + data, strb = self.slave_wd_get(sim) + assert data == exp_data + assert strb == exp_strb + + def slave_wr_get(self, sim: TestbenchContext): + wr = self.axi_lite.write_response + assert sim.get(wr.rdy) + + async def slave_wr_respond(self, sim: TestbenchContext, resp=0): + assert sim.get(self.axi_lite.write_response.rdy) + sim.set(self.axi_lite.write_response.resp, resp) + sim.set(self.axi_lite.write_response.valid, 1) + await sim.tick() + sim.set(self.axi_lite.write_response.valid, 0) + + +# TODO: this test needs a rewrite! +# 1. use queues instead of copy-pasting +# 2. handle each AXI pipe independently class TestAXILiteMaster(TestCaseWithSimulator): class AXILiteMasterTestModule(Elaboratable): def __init__(self, params: AXILiteParameters): @@ -103,161 +114,141 @@ def _(arg): def test_manual(self): almt = TestAXILiteMaster.AXILiteMasterTestModule(AXILiteParameters()) - def master_process(): + async def master_process(sim: TestbenchContext): # read request - yield from almt.read_address_request_adapter.call(addr=5, prot=0) + await almt.read_address_request_adapter.call(sim, addr=5, prot=0) - yield from almt.read_address_request_adapter.call(addr=10, prot=1) + await almt.read_address_request_adapter.call(sim, addr=10, prot=1) - yield from almt.read_address_request_adapter.call(addr=15, prot=1) + await almt.read_address_request_adapter.call(sim, addr=15, prot=1) - yield from almt.read_address_request_adapter.call(addr=20, prot=0) + await almt.read_address_request_adapter.call(sim, addr=20, prot=0) - yield from almt.write_request_adapter.call(addr=6, prot=0, data=10, strb=3) + await almt.write_request_adapter.call(sim, addr=6, prot=0, data=10, strb=3) - yield from almt.write_request_adapter.call(addr=7, prot=0, data=11, strb=3) + await almt.write_request_adapter.call(sim, addr=7, prot=0, data=11, strb=3) - yield from almt.write_request_adapter.call(addr=8, prot=0, data=12, strb=3) + await almt.write_request_adapter.call(sim, addr=8, prot=0, data=12, strb=3) - yield from almt.write_request_adapter.call(addr=9, prot=1, data=13, strb=4) + await almt.write_request_adapter.call(sim, addr=9, prot=1, data=13, strb=4) - yield from almt.read_address_request_adapter.call(addr=1, prot=1) + await almt.read_address_request_adapter.call(sim, addr=1, prot=1) - yield from almt.read_address_request_adapter.call(addr=2, prot=1) + await almt.read_address_request_adapter.call(sim, addr=2, prot=1) - def slave_process(): + async def slave_process(sim: TestbenchContext): slave = AXILiteInterfaceWrapper(almt.axi_lite_master.axil_master) # 1st request - yield from slave.slave_ra_ready(1) - yield from slave.slave_ra_wait() - yield from slave.slave_ra_verify(5, 0) - yield Settle() + slave.slave_ra_ready(sim, 1) + await sim.tick() + slave.slave_ra_get_and_verify(sim, 5, 0) # 2nd request and 1st respond - yield from slave.slave_ra_wait() - yield from slave.slave_rd_wait() - yield from slave.slave_ra_verify(10, 1) - yield from slave.slave_rd_respond(10, 0) - yield Settle() + await sim.tick() + slave.slave_ra_get_and_verify(sim, 10, 1) + slave.slave_rd_get(sim) + await slave.slave_rd_respond(sim, 10, 0) # 3rd request and 2nd respond - yield from slave.slave_ra_wait() - yield from slave.slave_rd_wait() - yield from slave.slave_ra_verify(15, 1) - yield from slave.slave_rd_respond(15, 0) - yield Settle() + slave.slave_ra_get_and_verify(sim, 15, 1) + slave.slave_rd_get(sim) + await slave.slave_rd_respond(sim, 15, 0) # 4th request and 3rd respond - yield from slave.slave_ra_wait() - yield from slave.slave_rd_wait() - yield from slave.slave_ra_verify(20, 0) - yield from slave.slave_rd_respond(20, 0) - yield Settle() + slave.slave_ra_get_and_verify(sim, 20, 0) + slave.slave_rd_get(sim) + await slave.slave_rd_respond(sim, 20, 0) # 4th respond and 1st write request - yield from slave.slave_ra_ready(0) - yield from slave.slave_wa_ready(1) - yield from slave.slave_wd_ready(1) - yield from slave.slave_rd_wait() - yield from slave.slave_wa_wait() - yield from slave.slave_wd_wait() - yield from slave.slave_wa_verify(6, 0) - yield from slave.slave_wd_verify(10, 3) - yield from slave.slave_rd_respond(25, 0) - yield Settle() + slave.slave_ra_ready(sim, 0) + slave.slave_wa_ready(sim, 1) + slave.slave_wd_ready(sim, 1) + slave.slave_rd_get(sim) + slave.slave_wa_get_and_verify(sim, 6, 0) + slave.slave_wd_get_and_verify(sim, 10, 3) + await slave.slave_rd_respond(sim, 25, 0) # 2nd write request and 1st respond - yield from slave.slave_wa_wait() - yield from slave.slave_wd_wait() - yield from slave.slave_wr_wait() - yield from slave.slave_wa_verify(7, 0) - yield from slave.slave_wd_verify(11, 3) - yield from slave.slave_wr_respond(1) - yield Settle() + slave.slave_wa_get_and_verify(sim, 7, 0) + slave.slave_wd_get_and_verify(sim, 11, 3) + slave.slave_wr_get(sim) + await slave.slave_wr_respond(sim, 1) # 3nd write request and 2st respond - yield from slave.slave_wa_wait() - yield from slave.slave_wd_wait() - yield from slave.slave_wr_wait() - yield from slave.slave_wa_verify(8, 0) - yield from slave.slave_wd_verify(12, 3) - yield from slave.slave_wr_respond(1) - yield Settle() + slave.slave_wa_get_and_verify(sim, 8, 0) + slave.slave_wd_get_and_verify(sim, 12, 3) + slave.slave_wr_get(sim) + await slave.slave_wr_respond(sim, 1) # 4th write request and 3rd respond - yield from slave.slave_wr_wait() - yield from slave.slave_wa_verify(9, 1) - yield from slave.slave_wd_verify(13, 4) - yield from slave.slave_wr_respond(1) - yield Settle() + slave.slave_wa_get_and_verify(sim, 9, 1) + slave.slave_wd_get_and_verify(sim, 13, 4) + slave.slave_wr_get(sim) + await slave.slave_wr_respond(sim, 1) # 4th respond - yield from slave.slave_wa_ready(0) - yield from slave.slave_wd_ready(0) - yield from slave.slave_wr_wait() - yield from slave.slave_wr_respond(0) - yield Settle() - - yield from slave.slave_ra_wait() - for _ in range(2): - yield Tick() - yield from slave.slave_ra_ready(1) - yield from slave.slave_ra_verify(1, 1) + slave.slave_wa_ready(sim, 0) + slave.slave_wd_ready(sim, 0) + slave.slave_wr_get(sim) + await slave.slave_wr_respond(sim, 0) + + slave.slave_ra_get(sim) + await self.tick(sim, 2) + slave.slave_ra_ready(sim, 1) + slave.slave_ra_get_and_verify(sim, 1, 1) # wait for next rising edge - yield Tick() - yield Tick() + await sim.tick() - yield from slave.slave_ra_wait() - yield from slave.slave_ra_verify(2, 1) - yield from slave.slave_rd_wait() - yield from slave.slave_rd_respond(3, 1) - yield Settle() + slave.slave_ra_get(sim) + slave.slave_ra_get_and_verify(sim, 2, 1) + slave.slave_rd_get(sim) + await slave.slave_rd_respond(sim, 3, 1) - yield from slave.slave_rd_wait() - yield from slave.slave_rd_respond(4, 1) + await slave.slave_rd_wait(sim) + await slave.slave_rd_respond(sim, 4, 1) - def result_process(): - resp = yield from almt.read_data_response_adapter.call() + async def result_process(sim: TestbenchContext): + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 10 assert resp["resp"] == 0 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 15 assert resp["resp"] == 0 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 20 assert resp["resp"] == 0 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 25 assert resp["resp"] == 0 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 1 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 1 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 1 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 0 for _ in range(5): - yield Tick() + await sim.tick() - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 3 assert resp["resp"] == 1 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 4 assert resp["resp"] == 1 with self.run_simulation(almt) as sim: - sim.add_process(master_process) - sim.add_process(slave_process) - sim.add_process(result_process) + sim.add_testbench(master_process) + sim.add_testbench(slave_process) + sim.add_testbench(result_process) diff --git a/test/peripherals/test_wishbone.py b/test/peripherals/test_wishbone.py index 68b04c6a0..d5a19f19b 100644 --- a/test/peripherals/test_wishbone.py +++ b/test/peripherals/test_wishbone.py @@ -1,7 +1,9 @@ +from collections.abc import Iterable import random from collections import deque from amaranth.lib.wiring import connect +from amaranth_types import ValueLike from coreblocks.peripherals.wishbone import * @@ -14,50 +16,80 @@ class WishboneInterfaceWrapper: def __init__(self, wishbone_interface: WishboneInterface): self.wb = wishbone_interface - def master_set(self, addr, data, we): - yield self.wb.dat_w.eq(data) - yield self.wb.adr.eq(addr) - yield self.wb.we.eq(we) - yield self.wb.cyc.eq(1) - yield self.wb.stb.eq(1) + def master_set(self, sim: SimulatorContext, addr: int, data: int, we: int): + sim.set(self.wb.dat_w, data) + sim.set(self.wb.adr, addr) + sim.set(self.wb.we, we) + sim.set(self.wb.cyc, 1) + sim.set(self.wb.stb, 1) - def master_release(self, release_cyc=1): - yield self.wb.stb.eq(0) + def master_release(self, sim: SimulatorContext, release_cyc: bool = True): + sim.set(self.wb.stb, 0) if release_cyc: - yield self.wb.cyc.eq(0) - - def master_verify(self, exp_data=0): - assert (yield self.wb.ack) - assert (yield self.wb.dat_r) == exp_data - - def slave_wait(self): - while not ((yield self.wb.stb) and (yield self.wb.cyc)): - yield Tick() - - def slave_verify(self, exp_addr, exp_data, exp_we, exp_sel=0): - assert (yield self.wb.stb) and (yield self.wb.cyc) - - assert (yield self.wb.adr) == exp_addr - assert (yield self.wb.we) == exp_we - assert (yield self.wb.sel) == exp_sel + sim.set(self.wb.cyc, 0) + + async def slave_wait(self, sim: SimulatorContext): + *_, adr, we, sel, dat_w = ( + await sim.tick() + .sample(self.wb.adr, self.wb.we, self.wb.sel, self.wb.dat_w) + .until(self.wb.stb & self.wb.cyc) + ) + return adr, we, sel, dat_w + + async def slave_wait_and_verify( + self, sim: SimulatorContext, exp_addr: int, exp_data: int, exp_we: int, exp_sel: int = 0 + ): + adr, we, sel, dat_w = await self.slave_wait(sim) + + assert adr == exp_addr + assert we == exp_we + assert sel == exp_sel if exp_we: - assert (yield self.wb.dat_w) == exp_data - - def slave_respond(self, data, ack=1, err=0, rty=0): - assert (yield self.wb.stb) and (yield self.wb.cyc) - - yield self.wb.dat_r.eq(data) - yield self.wb.ack.eq(ack) - yield self.wb.err.eq(err) - yield self.wb.rty.eq(rty) - yield Tick() - yield self.wb.ack.eq(0) - yield self.wb.err.eq(0) - yield self.wb.rty.eq(0) - - def wait_ack(self): - while not ((yield self.wb.stb) and (yield self.wb.cyc) and (yield self.wb.ack)): - yield Tick() + assert dat_w == exp_data + + async def slave_tick_and_verify( + self, sim: SimulatorContext, exp_addr: int, exp_data: int, exp_we: int, exp_sel: int = 0 + ): + *_, adr, we, sel, dat_w, stb, cyc = await sim.tick().sample( + self.wb.adr, self.wb.we, self.wb.sel, self.wb.dat_w, self.wb.stb, self.wb.cyc + ) + assert stb and cyc + + assert adr == exp_addr + assert we == exp_we + assert sel == exp_sel + if exp_we: + assert dat_w == exp_data + + async def slave_respond( + self, + sim: SimulatorContext, + data: int, + ack: int = 1, + err: int = 0, + rty: int = 0, + sample: Iterable[ValueLike] = (), + ): + assert sim.get(self.wb.stb) and sim.get(self.wb.cyc) + + sim.set(self.wb.dat_r, data) + sim.set(self.wb.ack, ack) + sim.set(self.wb.err, err) + sim.set(self.wb.rty, rty) + ret = await sim.tick().sample(*sample) + sim.set(self.wb.ack, 0) + sim.set(self.wb.err, 0) + sim.set(self.wb.rty, 0) + return ret + + async def slave_respond_master_verify( + self, sim: SimulatorContext, master: WishboneInterface, data: int, ack: int = 1, err: int = 0, rty: int = 0 + ): + *_, ack, dat_r = await self.slave_respond(sim, data, ack, err, rty, sample=[master.ack, master.dat_r]) + assert ack and dat_r == data + + async def wait_ack(self, sim: SimulatorContext): + await sim.tick().until(self.wb.stb & self.wb.cyc & self.wb.ack) class TestWishboneMaster(TestCaseWithSimulator): @@ -75,71 +107,63 @@ def elaborate(self, platform): def test_manual(self): twbm = TestWishboneMaster.WishboneMasterTestModule() - def process(): + async def process(sim: TestbenchContext): # read request - yield from twbm.requestAdapter.call(addr=2, data=0, we=0, sel=1) + await twbm.requestAdapter.call(sim, addr=2, data=0, we=0, sel=1) # read request after delay - yield Tick() - yield Tick() - yield from twbm.requestAdapter.call(addr=1, data=0, we=0, sel=1) + await sim.tick() + await sim.tick() + await twbm.requestAdapter.call(sim, addr=1, data=0, we=0, sel=1) # write request - yield from twbm.requestAdapter.call(addr=3, data=5, we=1, sel=0) + await twbm.requestAdapter.call(sim, addr=3, data=5, we=1, sel=0) # RTY and ERR responese - yield from twbm.requestAdapter.call(addr=2, data=0, we=0, sel=0) - resp = yield from twbm.requestAdapter.call_try(addr=0, data=0, we=0, sel=0) + await twbm.requestAdapter.call(sim, addr=2, data=0, we=0, sel=0) + resp = await twbm.requestAdapter.call_try(sim, addr=0, data=0, we=0, sel=0) assert resp is None # verify cycle restart - def result_process(): - resp = yield from twbm.resultAdapter.call() + async def result_process(sim: TestbenchContext): + resp = await twbm.resultAdapter.call(sim) assert resp["data"] == 8 assert not resp["err"] - resp = yield from twbm.resultAdapter.call() + resp = await twbm.resultAdapter.call(sim) assert resp["data"] == 3 assert not resp["err"] - resp = yield from twbm.resultAdapter.call() + resp = await twbm.resultAdapter.call(sim) assert not resp["err"] - resp = yield from twbm.resultAdapter.call() + resp = await twbm.resultAdapter.call(sim) assert resp["data"] == 1 assert resp["err"] - def slave(): + async def slave(sim: TestbenchContext): wwb = WishboneInterfaceWrapper(twbm.wbm.wb_master) - yield from wwb.slave_wait() - yield from wwb.slave_verify(2, 0, 0, 1) - yield from wwb.slave_respond(8) - yield Settle() + await wwb.slave_wait_and_verify(sim, 2, 0, 0, 1) + await wwb.slave_respond(sim, 8) - yield from wwb.slave_wait() - yield from wwb.slave_verify(1, 0, 0, 1) - yield from wwb.slave_respond(3) - yield Settle() + await wwb.slave_wait_and_verify(sim, 1, 0, 0, 1) + await wwb.slave_respond(sim, 3) - yield # consecutive request - yield from wwb.slave_verify(3, 5, 1, 0) - yield from wwb.slave_respond(0) - yield Tick() + await wwb.slave_tick_and_verify(sim, 3, 5, 1, 0) + await wwb.slave_respond(sim, 0) + await sim.tick() - yield # consecutive request - yield from wwb.slave_verify(2, 0, 0, 0) - yield from wwb.slave_respond(1, ack=0, err=0, rty=1) - yield Settle() - assert not (yield wwb.wb.stb) + await wwb.slave_tick_and_verify(sim, 2, 0, 0, 0) + await wwb.slave_respond(sim, 1, ack=0, err=0, rty=1) + assert not sim.get(wwb.wb.stb) - yield from wwb.slave_wait() - yield from wwb.slave_verify(2, 0, 0, 0) - yield from wwb.slave_respond(1, ack=1, err=1, rty=0) + await wwb.slave_wait_and_verify(sim, 2, 0, 0, 0) + await wwb.slave_respond(sim, 1, ack=1, err=1, rty=0) with self.run_simulation(twbm) as sim: - sim.add_process(process) - sim.add_process(result_process) - sim.add_process(slave) + sim.add_testbench(process) + sim.add_testbench(result_process) + sim.add_testbench(slave) class TestWishboneMuxer(TestCaseWithSimulator): @@ -149,97 +173,80 @@ def test_manual(self): slaves = [WishboneInterfaceWrapper(slave) for slave in mux.slaves] wb_master = WishboneInterfaceWrapper(mux.master_wb) - def process(): + async def process(sim: TestbenchContext): # check full communiaction - yield from wb_master.master_set(2, 0, 1) - yield mux.sselTGA.eq(0b0001) - yield Tick() - yield from slaves[0].slave_verify(2, 0, 1) - assert not (yield slaves[1].wb.stb) - yield from slaves[0].slave_respond(4) - yield from wb_master.master_verify(4) - yield from wb_master.master_release(release_cyc=0) - yield Tick() + wb_master.master_set(sim, 2, 0, 1) + sim.set(mux.sselTGA, 0b0001) + await slaves[0].slave_tick_and_verify(sim, 2, 0, 1) + assert not sim.get(slaves[1].wb.stb) + await slaves[0].slave_respond_master_verify(sim, wb_master.wb, 4) + wb_master.master_release(sim, release_cyc=False) + await sim.tick() # select without releasing cyc (only on stb) - yield from wb_master.master_set(3, 0, 0) - yield mux.sselTGA.eq(0b0010) - yield Tick() - assert not (yield slaves[0].wb.stb) - yield from slaves[1].slave_verify(3, 0, 0) - yield from slaves[1].slave_respond(5) - yield from wb_master.master_verify(5) - yield from wb_master.master_release() - yield Tick() + wb_master.master_set(sim, 3, 0, 0) + sim.set(mux.sselTGA, 0b0010) + await slaves[1].slave_tick_and_verify(sim, 3, 0, 0) + assert not sim.get(slaves[0].wb.stb) + await slaves[1].slave_respond_master_verify(sim, wb_master.wb, 5) + wb_master.master_release(sim) + await sim.tick() # normal selection - yield from wb_master.master_set(6, 0, 0) - yield mux.sselTGA.eq(0b1000) - yield Tick() - yield from slaves[3].slave_verify(6, 0, 0) - yield from slaves[3].slave_respond(1) - yield from wb_master.master_verify(1) + wb_master.master_set(sim, 6, 0, 0) + sim.set(mux.sselTGA, 0b1000) + await slaves[3].slave_tick_and_verify(sim, 6, 0, 0) + await slaves[3].slave_respond_master_verify(sim, wb_master.wb, 1) with self.run_simulation(mux) as sim: - sim.add_process(process) + sim.add_testbench(process) -class TestWishboneAribiter(TestCaseWithSimulator): +class TestWishboneArbiter(TestCaseWithSimulator): def test_manual(self): arb = WishboneArbiter(WishboneParameters(), 2) slave = WishboneInterfaceWrapper(arb.slave_wb) masters = [WishboneInterfaceWrapper(master) for master in arb.masters] - def process(): - yield from masters[0].master_set(2, 3, 1) - yield from slave.slave_wait() - yield from slave.slave_verify(2, 3, 1) - yield from masters[1].master_set(1, 4, 1) - yield from slave.slave_respond(0) - - yield from masters[0].master_verify() - assert not (yield masters[1].wb.ack) - yield from masters[0].master_release() - yield Tick() + async def process(sim: TestbenchContext): + masters[0].master_set(sim, 2, 3, 1) + await slave.slave_wait_and_verify(sim, 2, 3, 1) + masters[1].master_set(sim, 1, 4, 1) + await slave.slave_respond_master_verify(sim, masters[0].wb, 0) + assert not sim.get(masters[1].wb.ack) + masters[0].master_release(sim) + await sim.tick() # check if bus is granted to next master if previous ends cycle - yield from slave.slave_wait() - yield from slave.slave_verify(1, 4, 1) - yield from slave.slave_respond(0) - yield from masters[1].master_verify() - assert not (yield masters[0].wb.ack) - yield from masters[1].master_release() - yield Tick() + await slave.slave_wait_and_verify(sim, 1, 4, 1) + await slave.slave_respond_master_verify(sim, masters[1].wb, 0) + assert not sim.get(masters[0].wb.ack) + masters[1].master_release(sim) + await sim.tick() # check round robin behaviour (2 masters requesting *2) - yield from masters[0].master_set(1, 0, 0) - yield from masters[1].master_set(2, 0, 0) - yield from slave.slave_wait() - yield from slave.slave_verify(1, 0, 0) - yield from slave.slave_respond(3) - yield from masters[0].master_verify(3) - yield from masters[0].master_release() - yield from masters[1].master_release() - yield Tick() - assert not (yield slave.wb.cyc) - - yield from masters[0].master_set(1, 0, 0) - yield from masters[1].master_set(2, 0, 0) - yield from slave.slave_wait() - yield from slave.slave_verify(2, 0, 0) - yield from slave.slave_respond(0) - yield from masters[1].master_verify() + masters[0].master_set(sim, 1, 0, 0) + masters[1].master_set(sim, 2, 0, 0) + await slave.slave_wait_and_verify(sim, 1, 0, 0) + await slave.slave_respond_master_verify(sim, masters[0].wb, 3) + masters[0].master_release(sim) + masters[1].master_release(sim) + await sim.tick() + assert not sim.get(slave.wb.cyc) + + masters[0].master_set(sim, 1, 0, 0) + masters[1].master_set(sim, 2, 0, 0) + await slave.slave_wait_and_verify(sim, 2, 0, 0) + await slave.slave_respond_master_verify(sim, masters[1].wb, 0) # check if releasing stb keeps grant - yield from masters[1].master_release(release_cyc=0) - yield Tick() - yield from masters[1].master_set(3, 0, 0) - yield from slave.slave_wait() - yield from slave.slave_verify(3, 0, 0) - yield from slave.slave_respond(0) - yield from masters[1].master_verify() + masters[1].master_release(sim, release_cyc=False) + await sim.tick() + masters[1].master_set(sim, 3, 0, 0) + await slave.slave_wait_and_verify(sim, 3, 0, 0) + await slave.slave_respond_master_verify(sim, masters[1].wb, 0) with self.run_simulation(arb) as sim: - sim.add_process(process) + sim.add_testbench(process) class TestPipelinedWishboneMaster(TestCaseWithSimulator): @@ -254,7 +261,7 @@ def test_randomized(self): wb_params = WishboneParameters() pwbm = SimpleTestCircuit(PipelinedWishboneMaster((wb_params))) - def request_process(): + async def request_process(sim: TestbenchContext): for _ in range(requests): request = { "addr": random.randint(0, 2**wb_params.addr_width - 1), @@ -263,49 +270,46 @@ def request_process(): "sel": random.randint(0, 2**wb_params.granularity - 1), } req_queue.appendleft(request) - yield from pwbm.request.call(request) + await pwbm.request.call(sim, request) - def verify_process(): + async def verify_process(sim: TestbenchContext): for _ in range(requests): - while random.random() < 0.8: - yield Tick() + await self.random_wait_geom(sim, 0.8) - result = yield from pwbm.result.call() + result = await pwbm.result.call(sim) cres = res_queue.pop() assert result["data"] == cres assert not result["err"] - def slave_process(): - yield Passive() - + async def slave_process(sim: TestbenchContext): wbw = pwbm._dut.wb - while True: - if (yield wbw.cyc) and (yield wbw.stb): - assert not (yield wbw.stall) + async for *_, cyc, stb, stall, adr, dat_w, we, sel in sim.tick().sample( + wbw.cyc, wbw.stb, wbw.stall, wbw.adr, wbw.dat_w, wbw.we, wbw.sel + ): + if cyc and stb: + assert not stall assert req_queue c_req = req_queue.pop() - assert (yield wbw.adr) == c_req["addr"] - assert (yield wbw.dat_w) == c_req["data"] - assert (yield wbw.we) == c_req["we"] - assert (yield wbw.sel) == c_req["sel"] + assert adr == c_req["addr"] + assert dat_w == c_req["data"] + assert we == c_req["we"] + assert sel == c_req["sel"] - slave_queue.appendleft((yield wbw.dat_w)) - res_queue.appendleft((yield wbw.dat_w)) + slave_queue.appendleft(dat_w) + res_queue.appendleft(dat_w) if slave_queue and random.random() < 0.4: - yield wbw.ack.eq(1) - yield wbw.dat_r.eq(slave_queue.pop()) + sim.set(wbw.ack, 1) + sim.set(wbw.dat_r, slave_queue.pop()) else: - yield wbw.ack.eq(0) - - yield wbw.stall.eq(random.random() < 0.3) + sim.set(wbw.ack, 0) - yield Tick() + sim.set(wbw.stall, random.random() < 0.3) with self.run_simulation(pwbm) as sim: - sim.add_process(request_process) - sim.add_process(verify_process) - sim.add_process(slave_process) + sim.add_testbench(request_process) + sim.add_testbench(verify_process) + sim.add_testbench(slave_process, background=True) class WishboneMemorySlaveCircuit(Elaboratable): @@ -341,11 +345,10 @@ def setup_method(self): def test_randomized(self): req_queue = deque() - wr_queue = deque() mem_state = [0] * self.memsize - def request_process(): + async def request_process(sim: TestbenchContext): for _ in range(self.iters): req = { "addr": random.randint(0, self.memsize - 1), @@ -354,41 +357,27 @@ def request_process(): "sel": random.randint(0, 2**self.sel_width - 1), } req_queue.appendleft(req) - wr_queue.appendleft(req) - while random.random() < 0.2: - yield Tick() - yield from self.m.request.call(req) + await self.random_wait_geom(sim, 0.2) + await self.m.request.call(sim, req) - def result_process(): + async def result_process(sim: TestbenchContext): for _ in range(self.iters): - while random.random() < 0.2: - yield Tick() - res = yield from self.m.result.call() + await self.random_wait_geom(sim, 0.2) + res = await self.m.result.call(sim) req = req_queue.pop() if not req["we"]: assert res["data"] == mem_state[req["addr"]] - - def write_process(): - wwb = WishboneInterfaceWrapper(self.m.mem_master.wb_master) - for _ in range(self.iters): - yield from wwb.wait_ack() - req = wr_queue.pop() - - if req["we"]: + else: for i in range(self.sel_width): if req["sel"] & (1 << i): granularity_mask = (2**self.wb_params.granularity - 1) << (i * self.wb_params.granularity) mem_state[req["addr"]] &= ~granularity_mask mem_state[req["addr"]] |= req["data"] & granularity_mask - - yield Tick() - - if req["we"]: - assert (yield Value.cast(self.m.mem_slave.mem.data[req["addr"]])) == mem_state[req["addr"]] + val = sim.get(Value.cast(self.m.mem_slave.mem.data[req["addr"]])) + assert val == mem_state[req["addr"]] with self.run_simulation(self.m, max_cycles=3000) as sim: - sim.add_process(request_process) - sim.add_process(result_process) - sim.add_process(write_process) + sim.add_testbench(request_process) + sim.add_testbench(result_process) diff --git a/test/priv/traps/test_exception.py b/test/priv/traps/test_exception.py index 22ebb8b5e..824d892ba 100644 --- a/test/priv/traps/test_exception.py +++ b/test/priv/traps/test_exception.py @@ -38,16 +38,16 @@ def test_randomized(self): self.dut = SimpleTestCircuit( ExceptionInformationRegister( self.gen_params, self.rob_idx_mock.adapter.iface, self.fetch_stall_mock.adapter.iface - ) + ), ) m = ModuleConnector(self.dut, rob_idx_mock=self.rob_idx_mock, fetch_stall_mock=self.fetch_stall_mock) self.rob_id = 0 - def process_test(): + async def process_test(sim: TestbenchContext): saved_entry = None - yield from self.fetch_stall_mock.enable() + self.fetch_stall_mock.enable(sim) for _ in range(self.cycles): self.rob_id = random.randint(0, self.rob_max) @@ -61,12 +61,13 @@ def process_test(): report_arg = {"cause": cause, "rob_id": report_rob, "pc": report_pc, "mtval": report_mtval} expected = report_arg if self.should_update(report_arg, saved_entry, self.rob_id) else saved_entry - yield from self.dut.report.call(report_arg) - yield # additional FIFO delay + await self.dut.report.call(sim, report_arg) + # additional FIFO delay + *_, fetch_stall_mock_done = await self.fetch_stall_mock.sample_outputs_done(sim) - assert (yield from self.fetch_stall_mock.done()) + assert fetch_stall_mock_done - new_state = yield from self.dut.get.call() + new_state = data_const_to_dict(await self.dut.get.call(sim)) assert new_state == expected | {"valid": 1} # type: ignore @@ -77,4 +78,4 @@ def process_rob_idx_mock(): return {"start": self.rob_id, "end": 0} with self.run_simulation(m) as sim: - sim.add_process(process_test) + sim.add_testbench(process_test) diff --git a/test/regression/pysim.py b/test/regression/pysim.py index a21b293fe..ee8aa5990 100644 --- a/test/regression/pysim.py +++ b/test/regression/pysim.py @@ -2,22 +2,22 @@ import os import logging -from amaranth.sim import Passive, Settle, Tick from amaranth.utils import exact_log2 from amaranth import * from transactron.core.keys import TransactionManagerKey +from transactron.profiler import Profile +from transactron.testing.tick_count import make_tick_count_process from .memory import * from .common import SimulationBackend, SimulationExecutionResult from transactron.testing import ( PysimSimulator, - TestGen, profiler_process, - Profile, make_logging_process, parse_logging_level, + TestbenchContext, ) from transactron.utils.dependencies import DependencyContext, DependencyManager from transactron.lib.metrics import HardwareMetricsManager @@ -43,22 +43,20 @@ def __init__(self, traces_file: Optional[str] = None): def _wishbone_slave( self, mem_model: CoreMemoryModel, wb_ctrl: WishboneInterfaceWrapper, is_instr_bus: bool, delay: int = 0 ): - def f(): - yield Passive() - + async def f(sim: TestbenchContext): while True: - yield from wb_ctrl.slave_wait() + await wb_ctrl.slave_wait(sim) word_width_bytes = self.gp.isa.xlen // 8 # Wishbone is addressing words, so we need to shift it a bit to get the real address. - addr = (yield wb_ctrl.wb.adr) << exact_log2(word_width_bytes) - sel = yield wb_ctrl.wb.sel - dat_w = yield wb_ctrl.wb.dat_w + addr = sim.get(wb_ctrl.wb.adr) << exact_log2(word_width_bytes) + sel = sim.get(wb_ctrl.wb.sel) + dat_w = sim.get(wb_ctrl.wb.dat_w) resp_data = 0 - if (yield wb_ctrl.wb.we): + if sim.get(wb_ctrl.wb.we): resp = mem_model.write( WriteRequest(addr=addr, data=dat_w, byte_count=word_width_bytes, byte_sel=sel) ) @@ -83,21 +81,19 @@ def f(): rty = 1 for _ in range(delay): - yield Tick() - - yield from wb_ctrl.slave_respond(resp_data, ack=ack, err=err, rty=rty) + await sim.tick() - yield Settle() + await wb_ctrl.slave_respond(sim, resp_data, ack=ack, err=err, rty=rty) return f - def _waiter(self, on_finish: Callable[[], TestGen[None]]): - def f(): + def _waiter(self, on_finish: Callable[[TestbenchContext], None]): + async def f(sim: TestbenchContext): while self.running: self.cycle_cnt += 1 - yield Tick() + await sim.tick() - yield from on_finish() + on_finish(sim) return f @@ -141,13 +137,14 @@ async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> S self.cycle_cnt = 0 sim = PysimSimulator(core, max_cycles=timeout_cycles, traces_file=self.traces_file) - sim.add_process(self._wishbone_slave(mem_model, wb_instr_ctrl, is_instr_bus=True)) - sim.add_process(self._wishbone_slave(mem_model, wb_data_ctrl, is_instr_bus=False)) + sim.add_testbench(self._wishbone_slave(mem_model, wb_instr_ctrl, is_instr_bus=True), background=True) + sim.add_testbench(self._wishbone_slave(mem_model, wb_data_ctrl, is_instr_bus=False), background=True) def on_error(): raise RuntimeError("Simulation finished due to an error") sim.add_process(make_logging_process(self.log_level, self.log_filter, on_error)) + sim.add_process(make_tick_count_process()) # This enables logging in benchmarks. TODO: after unifying regression testing, remove. logging.basicConfig() @@ -161,17 +158,17 @@ def on_error(): metric_values: dict[str, dict[str, int]] = {} - def on_sim_finish(): + def on_sim_finish(sim: TestbenchContext): # Collect metric values before we finish the simulation for metric_name, metric in self.metrics_manager.get_metrics().items(): metric = self.metrics_manager.get_metrics()[metric_name] metric_values[metric_name] = {} for reg_name in metric.regs: - metric_values[metric_name][reg_name] = yield self.metrics_manager.get_register_value( - metric_name, reg_name + metric_values[metric_name][reg_name] = sim.get( + self.metrics_manager.get_register_value(metric_name, reg_name) ) - sim.add_process(self._waiter(on_finish=on_sim_finish)) + sim.add_testbench(self._waiter(on_finish=on_sim_finish)) success = sim.run() self.pretty_dump_metrics(metric_values) diff --git a/test/scheduler/test_rs_selection.py b/test/scheduler/test_rs_selection.py index d00ac64f3..9a7e7d48b 100644 --- a/test/scheduler/test_rs_selection.py +++ b/test/scheduler/test_rs_selection.py @@ -2,7 +2,6 @@ import random from amaranth import * -from amaranth.sim import Settle, Passive from coreblocks.params import GenParams from coreblocks.interface.layouts import RSLayouts, SchedulerLayouts @@ -11,7 +10,9 @@ from coreblocks.params.configurations import test_core_config from coreblocks.scheduler.scheduler import RSSelection from transactron.lib import FIFO, Adapter, AdapterTrans -from transactron.testing import TestCaseWithSimulator, TestbenchIO +from transactron.testing import TestCaseWithSimulator, TestbenchIO, TestbenchContext +from transactron.testing.functions import data_const_to_dict +from transactron.testing.method_mock import MethodMock, def_method_mock _rs1_optypes = {OpType.ARITHMETIC, OpType.COMPARE} _rs2_optypes = {OpType.LOGIC, OpType.COMPARE} @@ -52,12 +53,12 @@ class TestRSSelect(TestCaseWithSimulator): def setup_method(self): self.gen_params = GenParams(test_core_config) self.m = RSSelector(self.gen_params) - self.expected_out = deque() - self.instr_in = deque() + self.expected_out: deque[dict] = deque() + self.instr_in: deque[dict] = deque() random.seed(1789) def create_instr_input_process(self, instr_count: int, optypes: set[OpType], random_wait: int = 0): - def process(): + async def process(sim: TestbenchContext): for i in range(instr_count): rp_dst = random.randrange(self.gen_params.phys_regs_bits) rp_s1 = random.randrange(self.gen_params.phys_regs_bits) @@ -91,41 +92,36 @@ def process(): } self.instr_in.append(instr) - yield from self.m.instr_in.call(instr) - yield from self.random_wait(random_wait) + await self.m.instr_in.call(sim, instr) + await self.random_wait(sim, random_wait) return process - def create_rs_alloc_process(self, io: TestbenchIO, rs_id: int, rs_optypes: set[OpType], random_wait: int = 0): - def mock(): + def create_rs_alloc_process(self, io: TestbenchIO, rs_id: int, rs_optypes: set[OpType], enable_prob: float = 1): + @def_method_mock(lambda: io, enable=lambda: random.random() <= enable_prob) + def process(): random_entry = random.randrange(self.gen_params.max_rs_entries) - expected = self.instr_in.popleft() - assert expected["exec_fn"]["op_type"] in rs_optypes - expected["rs_entry_id"] = random_entry - expected["rs_selected"] = rs_id - self.expected_out.append(expected) - return {"rs_entry_id": random_entry} + @MethodMock.effect + def eff(): + expected = self.instr_in.popleft() + assert expected["exec_fn"]["op_type"] in rs_optypes + expected["rs_entry_id"] = random_entry + expected["rs_selected"] = rs_id + self.expected_out.append(expected) - def process(): - yield Passive() - while True: - yield from io.enable() - yield from io.method_handle(mock) - yield from io.disable() - yield from self.random_wait(random_wait) + return {"rs_entry_id": random_entry} - return process + return process() def create_output_process(self, instr_count: int, random_wait: int = 0): - def process(): + async def process(sim: TestbenchContext): for _ in range(instr_count): - result = yield from self.m.instr_out.call() + result = await self.m.instr_out.call(sim) outputs = self.expected_out.popleft() - yield from self.random_wait(random_wait) - yield Settle() - assert result == outputs + await self.random_wait(sim, random_wait) + assert data_const_to_dict(result) == outputs return process @@ -135,10 +131,10 @@ def test_base_functionality(self): """ with self.run_simulation(self.m, max_cycles=1500) as sim: - sim.add_process(self.create_instr_input_process(100, _rs1_optypes.union(_rs2_optypes))) - sim.add_process(self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) - sim.add_process(self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) - sim.add_process(self.create_output_process(100)) + sim.add_testbench(self.create_instr_input_process(100, _rs1_optypes.union(_rs2_optypes))) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) + sim.add_testbench(self.create_output_process(100)) def test_only_rs1(self): """ @@ -147,9 +143,9 @@ def test_only_rs1(self): """ with self.run_simulation(self.m, max_cycles=1500) as sim: - sim.add_process(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) - sim.add_process(self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) - sim.add_process(self.create_output_process(100)) + sim.add_testbench(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) + sim.add_testbench(self.create_output_process(100)) def test_only_rs2(self): """ @@ -158,9 +154,9 @@ def test_only_rs2(self): """ with self.run_simulation(self.m, max_cycles=1500) as sim: - sim.add_process(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) - sim.add_process(self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) - sim.add_process(self.create_output_process(100)) + sim.add_testbench(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) + sim.add_testbench(self.create_output_process(100)) def test_delays(self): """ @@ -169,11 +165,11 @@ def test_delays(self): """ with self.run_simulation(self.m, max_cycles=5000) as sim: - sim.add_process(self.create_instr_input_process(300, _rs1_optypes.union(_rs2_optypes), random_wait=4)) - sim.add_process( - self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes, random_wait=12) + sim.add_testbench(self.create_instr_input_process(300, _rs1_optypes.union(_rs2_optypes), random_wait=4)) + self.add_mock( + sim, self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes, enable_prob=0.1) ) - sim.add_process( - self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes, random_wait=12) + self.add_mock( + sim, self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes, enable_prob=0.1) ) - sim.add_process(self.create_output_process(300, random_wait=12)) + sim.add_testbench(self.create_output_process(300, random_wait=12)) diff --git a/test/scheduler/test_scheduler.py b/test/scheduler/test_scheduler.py index 293ce201f..9c54c043b 100644 --- a/test/scheduler/test_scheduler.py +++ b/test/scheduler/test_scheduler.py @@ -3,15 +3,15 @@ from collections import namedtuple, deque from typing import Callable, Optional, Iterable from amaranth import * -from amaranth.lib.data import View -from amaranth.sim import Settle, Tick from parameterized import parameterized_class from coreblocks.interface.keys import CoreStateKey -from coreblocks.interface.layouts import ROBLayouts, RetirementLayouts +from coreblocks.interface.layouts import RetirementLayouts from coreblocks.func_blocks.fu.common.rs_func_block import RSBlockComponent from transactron.core import Method from transactron.lib import FIFO, AdapterTrans, Adapter +from transactron.testing.functions import MethodData, data_const_to_dict +from transactron.testing.method_mock import MethodMock from transactron.utils.dependencies import DependencyContext from coreblocks.scheduler.scheduler import Scheduler from coreblocks.core_structs.rf import RegisterFile @@ -22,7 +22,7 @@ from coreblocks.params.configurations import test_core_config from coreblocks.core_structs.rob import ReorderBuffer from coreblocks.func_blocks.interface.func_protocols import FuncBlock -from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestGen, TestbenchIO, def_method_mock +from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, TestbenchContext class SchedulerTestCircuit(Elaboratable): @@ -157,7 +157,7 @@ def free_phys_reg(self, reg_id): self.free_regs_queue.append({"reg_id": reg_id}) self.expected_phys_reg_queue.append(reg_id) - def queue_gather(self, queues: Iterable[deque]): + async def queue_gather(self, sim: TestbenchContext, queues: Iterable[deque]): # Iterate over all 'queues' and take one element from each, gathering # all key-value pairs into 'item'. item = {} @@ -166,6 +166,7 @@ def queue_gather(self, queues: Iterable[deque]): # retry until we get an element while partial_item is None: # get element from one queue + await sim.delay(1e-9) if q: partial_item = q.popleft() # None signals to end the process @@ -173,7 +174,7 @@ def queue_gather(self, queues: Iterable[deque]): return None else: # if no element available, wait and retry on the next clock cycle - yield Tick() + await sim.tick() # merge queue element with all previous ones (dict merge) item = item | partial_item @@ -185,7 +186,7 @@ def make_queue_process( io: TestbenchIO, input_queues: Optional[Iterable[deque]] = None, output_queues: Optional[Iterable[deque]] = None, - check: Optional[Callable[[RecordIntDict, RecordIntDict], TestGen[None]]] = None, + check: Optional[Callable[[TestbenchContext, MethodData, dict], None]] = None, always_enable: bool = False, ): """Create queue gather-and-test process @@ -235,31 +236,30 @@ def make_queue_process( If neither `input_queues` nor `output_queues` are supplied. """ - def queue_process(): + async def queue_process(sim: TestbenchContext): if always_enable: - yield from io.enable() + io.enable(sim) while True: inputs = {} outputs = {} # gather items from both queues if input_queues is not None: - inputs = yield from self.queue_gather(input_queues) + inputs = await self.queue_gather(sim, input_queues) if output_queues is not None: - outputs = yield from self.queue_gather(output_queues) + outputs = await self.queue_gather(sim, output_queues) # Check if queues signalled to end the process if inputs is None or outputs is None: return - result = yield from io.call(inputs) + result = await io.call(sim, inputs) if always_enable: - yield from io.enable() + io.enable(sim) # this could possibly be extended to automatically compare 'results' and # 'outputs' if check is None but that needs some dict deepcompare if check is not None: - yield Settle() - yield from check(result, outputs) + check(sim, result, outputs) if output_queues is None and input_queues is None: raise ValueError("Either output_queues or input_queues must be supplied") @@ -267,44 +267,39 @@ def queue_process(): return queue_process def make_output_process(self, io: TestbenchIO, output_queues: Iterable[deque]): - def check(got, expected): - rl_dst = yield View( - self.gen_params.get(ROBLayouts).data_layout, - C( - (yield Value.cast(self.m.rob.data.data[got["rs_data"]["rob_id"]])), - self.gen_params.get(ROBLayouts).data_layout.size, - ), - ).rl_dst + def check(sim: TestbenchContext, got: MethodData, expected: dict): + # TODO: better stubs for Memory? + rl_dst = sim.get(self.m.rob.data.data[got.rs_data.rob_id].rl_dst) # type: ignore s1 = self.rf_state[expected["rp_s1"]] s2 = self.rf_state[expected["rp_s2"]] # if source operand register ids are 0 then we already have values - assert got["rs_data"]["rp_s1"] == (expected["rp_s1"] if not s1.valid else 0) - assert got["rs_data"]["rp_s2"] == (expected["rp_s2"] if not s2.valid else 0) - assert got["rs_data"]["rp_dst"] == expected["rp_dst"] - assert got["rs_data"]["exec_fn"] == expected["exec_fn"] - assert got["rs_entry_id"] == expected["rs_entry_id"] - assert got["rs_data"]["s1_val"] == (s1.value if s1.valid else 0) - assert got["rs_data"]["s2_val"] == (s2.value if s2.valid else 0) + assert got.rs_data.rp_s1 == (expected["rp_s1"] if not s1.valid else 0) + assert got.rs_data.rp_s2 == (expected["rp_s2"] if not s2.valid else 0) + assert got.rs_data.rp_dst == expected["rp_dst"] + assert data_const_to_dict(got.rs_data.exec_fn) == expected["exec_fn"] + assert got.rs_entry_id == expected["rs_entry_id"] + assert got.rs_data.s1_val == (s1.value if s1.valid else 0) + assert got.rs_data.s2_val == (s2.value if s2.valid else 0) assert rl_dst == expected["rl_dst"] # recycle physical register number - if got["rs_data"]["rp_dst"] != 0: - self.free_phys_reg(got["rs_data"]["rp_dst"]) + if got.rs_data.rp_dst != 0: + self.free_phys_reg(got.rs_data.rp_dst) # recycle ROB entry - self.free_ROB_entries_queue.append({"rob_id": got["rs_data"]["rob_id"]}) + self.free_ROB_entries_queue.append({"rob_id": got.rs_data.rob_id}) return self.make_queue_process(io=io, output_queues=output_queues, check=check, always_enable=True) def test_randomized(self): - def instr_input_process(): - yield from self.m.rob_retire.enable() + async def instr_input_process(sim: TestbenchContext): + self.m.rob_retire.enable(sim) # set up RF to reflect our static rf_state reference lookup table for i in range(2**self.gen_params.phys_regs_bits - 1): - yield from self.m.rf_write.call(reg_id=i, reg_val=self.rf_state[i].value) + await self.m.rf_write.call(sim, reg_id=i, reg_val=self.rf_state[i].value) if not self.rf_state[i].valid: - yield from self.m.rf_free.call(reg_id=i) + await self.m.rf_free.call(sim, reg_id=i) op_types_set = set() for rs in self.optype_sets: @@ -338,7 +333,8 @@ def instr_input_process(): ) self.current_RAT[rl_dst] = rp_dst - yield from self.m.instr_inp.call( + await self.m.instr_inp.call( + sim, { "exec_fn": { "op_type": op_type, @@ -351,7 +347,7 @@ def instr_input_process(): "rl_dst": rl_dst, }, "imm": immediate, - } + }, ) # Terminate other processes self.expected_rename_queue.append(None) @@ -362,19 +358,22 @@ def rs_alloc_process(io: TestbenchIO, rs_id: int): @def_method_mock(lambda: io) def process(): random_entry = random.randrange(self.gen_params.max_rs_entries) - expected = self.expected_rename_queue.popleft() - expected["rs_entry_id"] = random_entry - self.expected_rs_entry_queue[rs_id].append(expected) - # if last instruction was allocated stop simulation - self.allocated_instr_count += 1 - if self.allocated_instr_count == self.instr_count: - for i in range(self.rs_count): - self.expected_rs_entry_queue[i].append(None) + @MethodMock.effect + def eff(): + expected = self.expected_rename_queue.popleft() + expected["rs_entry_id"] = random_entry + self.expected_rs_entry_queue[rs_id].append(expected) + + # if last instruction was allocated stop simulation + self.allocated_instr_count += 1 + if self.allocated_instr_count == self.instr_count: + for i in range(self.rs_count): + self.expected_rs_entry_queue[i].append(None) return {"rs_entry_id": random_entry} - return process + return process() @def_method_mock(lambda: self.m.core_state) def core_state_mock(): @@ -383,10 +382,10 @@ def core_state_mock(): with self.run_simulation(self.m, max_cycles=1500) as sim: for i in range(self.rs_count): - sim.add_process( + sim.add_testbench( self.make_output_process(io=self.m.rs_insert[i], output_queues=[self.expected_rs_entry_queue[i]]) ) - sim.add_process(rs_alloc_process(self.m.rs_alloc[i], i)) - sim.add_process(self.make_queue_process(io=self.m.rob_done, input_queues=[self.free_ROB_entries_queue])) - sim.add_process(self.make_queue_process(io=self.m.free_rf_inp, input_queues=[self.free_regs_queue])) - sim.add_process(instr_input_process) + self.add_mock(sim, rs_alloc_process(self.m.rs_alloc[i], i)) + sim.add_testbench(self.make_queue_process(io=self.m.rob_done, input_queues=[self.free_ROB_entries_queue])) + sim.add_testbench(self.make_queue_process(io=self.m.free_rf_inp, input_queues=[self.free_regs_queue])) + sim.add_testbench(instr_input_process) diff --git a/test/scheduler/test_wakeup_select.py b/test/scheduler/test_wakeup_select.py index b51af3cd3..cd34de905 100644 --- a/test/scheduler/test_wakeup_select.py +++ b/test/scheduler/test_wakeup_select.py @@ -1,7 +1,6 @@ from typing import Optional, cast from amaranth import * from amaranth.lib.data import StructLayout -from amaranth.sim import Settle, Tick from collections import deque from enum import Enum @@ -16,7 +15,8 @@ from transactron.lib import Adapter from coreblocks.scheduler.wakeup_select import * -from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestbenchIO +from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestbenchIO, TestbenchContext +from transactron.testing.functions import data_const_to_dict class WakeupTestCircuit(Elaboratable): @@ -76,46 +76,38 @@ def maybe_insert(self, rs: list[Optional[RecordIntDict]]): empty_idx -= 1 return 0 - def process(self): + async def process(self, sim: TestbenchContext): inserted_count = 0 issued_count = 0 rs: list[Optional[RecordIntDict]] = [None for _ in range(self.m.gen_params.max_rs_entries)] - yield from self.m.take_row_mock.enable() - yield from self.m.issue_mock.enable() - yield Settle() + self.m.take_row_mock.enable(sim) + self.m.issue_mock.enable(sim) for _ in range(self.cycles): inserted_count += self.maybe_insert(rs) - ready = Cat(entry is not None for entry in rs) + ready = Const.cast(Cat(entry is not None for entry in rs)) - yield from self.m.ready_mock.call_init(ready_list=ready) - if any(entry is not None for entry in rs): - yield from self.m.ready_mock.enable() - else: - yield from self.m.ready_mock.disable() + self.m.ready_mock.call_init(sim, ready_list=ready) + self.m.ready_mock.set_enable(sim, any(entry is not None for entry in rs)) - yield Settle() - - take_position = yield from self.m.take_row_mock.call_result() + take_position = self.m.take_row_mock.get_call_result(sim) if take_position is not None: take_position = cast(int, take_position["rs_entry_id"]) entry = rs[take_position] assert entry is not None self.taken.append(entry) - yield from self.m.take_row_mock.call_init(entry) + self.m.take_row_mock.call_init(sim, entry) rs[take_position] = None - yield Settle() - - issued = yield from self.m.issue_mock.call_result() + issued = self.m.issue_mock.get_call_result(sim) if issued is not None: - assert issued == self.taken.popleft() + assert data_const_to_dict(issued) == self.taken.popleft() issued_count += 1 - yield Tick() + await sim.tick() assert inserted_count != 0 assert inserted_count == issued_count def test(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) diff --git a/test/test_core.py b/test/test_core.py index c1875616d..39e2bfef3 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1,10 +1,13 @@ +from collections.abc import Callable +from typing import Any from amaranth import * from amaranth.lib.wiring import connect -from amaranth.sim import Passive, Tick +from amaranth_types import ValueLike +from transactron.testing.tick_count import TicksKey from transactron.utils import align_to_power_of_two -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, ProcessContext, TestbenchContext from coreblocks.arch.isa_consts import PrivilegeLevel from coreblocks.core import Core @@ -18,6 +21,8 @@ import tempfile from parameterized import parameterized_class +from transactron.utils.dependencies import DependencyContext + class CoreTestElaboratable(Elaboratable): def __init__(self, gen_params: GenParams, instr_mem: list[int] = [0], data_mem: list[int] = []): @@ -58,11 +63,11 @@ class TestCoreBase(TestCaseWithSimulator): gen_params: GenParams m: CoreTestElaboratable - def get_phys_reg_rrat(self, reg_id): - return (yield self.m.core.RRAT.entries[reg_id]) + def get_phys_reg_rrat(self, sim: TestbenchContext, reg_id): + return sim.get(self.m.core.RRAT.entries[reg_id]) - def get_arch_reg_val(self, reg_id): - return (yield self.m.core.RF.entries[(yield from self.get_phys_reg_rrat(reg_id))].reg_val) + def get_arch_reg_val(self, sim: TestbenchContext, reg_id): + return sim.get(self.m.core.RF.entries[(self.get_phys_reg_rrat(sim, reg_id))].reg_val) class TestCoreAsmSourceBase(TestCoreBase): @@ -148,12 +153,11 @@ class TestCoreBasicAsm(TestCoreAsmSourceBase): expected_regvals: dict[int, int] configuration: CoreConfiguration - def run_and_check(self): - for _ in range(self.cycle_count): - yield Tick() + async def run_and_check(self, sim: TestbenchContext): + await self.tick(sim, self.cycle_count) for reg_id, val in self.expected_regvals.items(): - assert (yield from self.get_arch_reg_val(reg_id)) == val + assert self.get_arch_reg_val(sim, reg_id) == val def test_asm_source(self): self.gen_params = GenParams(self.configuration) @@ -166,7 +170,7 @@ def test_asm_source(self): self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src["text"], data_mem=bin_src["data"]) with self.run_simulation(self.m) as sim: - sim.add_process(self.run_and_check) + sim.add_testbench(self.run_and_check) # test interrupts with varying triggering frequency (parametrizable amount of cycles between @@ -208,40 +212,37 @@ def setup_method(self): self.gen_params = GenParams(self.configuration) random.seed(1500100900) - def clear_level_interrupt_procsess(self): - yield Passive() - while True: - while (yield self.m.core.csr_generic.csr_coreblocks_test.value) == 0: - yield Tick() + async def clear_level_interrupt_process(self, sim: ProcessContext): + async for *_, value in sim.tick().sample(self.m.core.csr_generic.csr_coreblocks_test.value): + if value == 0: + continue - if (yield self.m.core.csr_generic.csr_coreblocks_test.value) == 2: + if value == 2: assert False, "`fail` called" - yield self.m.core.csr_generic.csr_coreblocks_test.value.eq(0) - yield self.m.interrupt_level.eq(0) - yield Tick() + sim.set(self.m.core.csr_generic.csr_coreblocks_test.value, 0) + sim.set(self.m.interrupt_level, 0) - def run_with_interrupt_process(self): + async def run_with_interrupt_process(self, sim: TestbenchContext): main_cycles = 0 int_count = 0 handler_count = 0 # wait for interrupt enable - while (yield self.m.core.interrupt_controller.mstatus_mie.value) == 0: - yield Tick() + await sim.tick().until(self.m.core.interrupt_controller.mstatus_mie.value) - def do_interrupt(): + async def do_interrupt(): count = 0 trig = random.randint(1, 3) - mie = (yield self.m.core.interrupt_controller.mie.value) >> 16 + mie = sim.get(self.m.core.interrupt_controller.mie.value) >> 16 if mie != 0b11 or trig & 1 or self.edge_only: - yield self.m.interrupt_edge.eq(1) + sim.set(self.m.interrupt_edge, 1) count += 1 - if (mie != 0b11 or trig & 2) and (yield self.m.interrupt_level) == 0 and not self.edge_only: - yield self.m.interrupt_level.eq(1) + if (mie != 0b11 or trig & 2) and sim.get(self.m.interrupt_level) == 0 and not self.edge_only: + sim.set(self.m.interrupt_level, 1) count += 1 - yield Tick() - yield self.m.interrupt_edge.eq(0) + await sim.tick() + sim.set(self.m.interrupt_edge, 0) return count early_interrupt = False @@ -250,40 +251,35 @@ def do_interrupt(): # run main code for some semi-random amount of cycles c = random.randrange(self.lo, self.hi) main_cycles += c - yield from self.tick(c) + await self.tick(sim, c) # trigger an interrupt - int_count += yield from do_interrupt() + int_count += await do_interrupt() # wait for the interrupt to get registered - while (yield self.m.core.interrupt_controller.mstatus_mie.value) == 1: - yield Tick() + await sim.tick().until(self.m.core.interrupt_controller.mstatus_mie.value != 1) # trigger interrupt during execution of ISR handler (blocked-pending) with some chance early_interrupt = random.random() < 0.4 if early_interrupt: # wait until interrupts are cleared, so it won't be missed - while (yield self.m.core.interrupt_controller.mip.value) != 0: - yield Tick() - - assert (yield from self.get_arch_reg_val(30)) == int_count + await sim.tick().until(self.m.core.interrupt_controller.mip.value == 0) + assert self.get_arch_reg_val(sim, 30) == int_count - int_count += yield from do_interrupt() + int_count += await do_interrupt() else: - while (yield self.m.core.interrupt_controller.mip.value) != 0: - yield Tick() - assert (yield from self.get_arch_reg_val(30)) == int_count + await sim.tick().until(self.m.core.interrupt_controller.mip.value == 0) + assert self.get_arch_reg_val(sim, 30) == int_count handler_count += 1 # wait until ISR returns - while (yield self.m.core.interrupt_controller.mstatus_mie.value) == 0: - yield Tick() + await sim.tick().until(self.m.core.interrupt_controller.mstatus_mie.value != 0) - assert (yield from self.get_arch_reg_val(30)) == int_count - assert (yield from self.get_arch_reg_val(27)) == handler_count + assert self.get_arch_reg_val(sim, 30) == int_count + assert self.get_arch_reg_val(sim, 27) == handler_count for reg_id, val in self.expected_regvals.items(): - assert (yield from self.get_arch_reg_val(reg_id)) == val + assert self.get_arch_reg_val(sim, reg_id) == val def test_interrupted_prog(self): bin_src = self.prepare_source(self.source_file) @@ -291,14 +287,14 @@ def test_interrupted_prog(self): bin_src["data"][self.reg_init_mem_offset // 4 + reg_id] = val self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src["text"], data_mem=bin_src["data"]) with self.run_simulation(self.m) as sim: - sim.add_process(self.run_with_interrupt_process) - sim.add_process(self.clear_level_interrupt_procsess) + sim.add_testbench(self.run_with_interrupt_process) + sim.add_process(self.clear_level_interrupt_process) @parameterized_class( ("source_file", "cycle_count", "expected_regvals", "always_mmode"), [ - ("user_mode.asm", 1000, {4: 5}, False), + ("user_mode.asm", 1100, {4: 5}, False), ("wfi_no_mie.asm", 250, {8: 8}, True), # only using level enable ], ) @@ -315,48 +311,45 @@ def setup_method(self): self.gen_params = GenParams(self.configuration) random.seed(161453) - def run_with_interrupt_process(self): - cycles = 0 + async def run_with_interrupt_process(self, sim: TestbenchContext): + ticks = DependencyContext.get().get_dependency(TicksKey()) # wait for interrupt enable - while (yield self.m.core.interrupt_controller.mie.value) == 0 and cycles < self.cycle_count: - cycles += 1 - yield Tick() - yield from self.random_wait(5) + async def wait_or_timeout(cond: ValueLike, pred: Callable[[Any], bool]): + async for *_, value in sim.tick().sample(cond): + if pred(value) or sim.get(ticks) >= self.cycle_count: + break - while cycles < self.cycle_count: - yield self.m.interrupt_level.eq(1) - cycles += 1 - yield Tick() + await wait_or_timeout(self.m.core.interrupt_controller.mie.value, lambda value: value != 0) + await self.random_wait(sim, 5) + + while sim.get(ticks) < self.cycle_count: + sim.set(self.m.interrupt_level, 1) if self.always_mmode: # if test happens only in m_mode, just enable fixed interrupt + await sim.tick() continue # wait for the interrupt to get registered - while ( - yield self.m.core.csr_generic.m_mode.priv_mode.value - ) != PrivilegeLevel.MACHINE and cycles < self.cycle_count: - cycles += 1 - yield Tick() + await wait_or_timeout( + self.m.core.csr_generic.m_mode.priv_mode.value, lambda value: value == PrivilegeLevel.MACHINE + ) - yield self.m.interrupt_level.eq(0) - yield Tick() + sim.set(self.m.interrupt_level, 0) # wait until ISR returns - while ( - yield self.m.core.csr_generic.m_mode.priv_mode.value - ) == PrivilegeLevel.MACHINE and cycles < self.cycle_count: - cycles += 1 - yield Tick() + await wait_or_timeout( + self.m.core.csr_generic.m_mode.priv_mode.value, lambda value: value != PrivilegeLevel.MACHINE + ) - yield from self.random_wait(5) + await self.random_wait(sim, 5) for reg_id, val in self.expected_regvals.items(): - assert (yield from self.get_arch_reg_val(reg_id)) == val + assert self.get_arch_reg_val(sim, reg_id) == val def test_interrupted_prog(self): bin_src = self.prepare_source(self.source_file) self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src["text"], data_mem=bin_src["data"]) with self.run_simulation(self.m) as sim: - sim.add_process(self.run_with_interrupt_process) + sim.add_testbench(self.run_with_interrupt_process) diff --git a/test/transactron/core/test_transactions.py b/test/transactron/core/test_transactions.py index 46ef5f6d7..fd4f9e7e0 100644 --- a/test/transactron/core/test_transactions.py +++ b/test/transactron/core/test_transactions.py @@ -1,5 +1,6 @@ from abc import abstractmethod from unittest.case import TestCase +from amaranth_types import HasElaborate import pytest from amaranth import * from amaranth.sim import * @@ -56,52 +57,52 @@ def count_test(self, sched, cnt): assert len(sched.grant) == cnt assert len(sched.valid) == 1 - def sim_step(self, sched, request, expected_grant): - yield sched.requests.eq(request) - yield Tick() + async def sim_step(self, sim, sched: Scheduler, request: int, expected_grant: int): + sim.set(sched.requests, request) + _, _, valid, grant = await sim.tick().sample(sched.valid, sched.grant) if request == 0: - assert not (yield sched.valid) + assert not valid else: - assert (yield sched.grant) == expected_grant - assert (yield sched.valid) + assert grant == expected_grant + assert valid def test_single(self): sched = Scheduler(1) self.count_test(sched, 1) - def process(): - yield from self.sim_step(sched, 0, 0) - yield from self.sim_step(sched, 1, 1) - yield from self.sim_step(sched, 1, 1) - yield from self.sim_step(sched, 0, 0) + async def process(sim): + await self.sim_step(sim, sched, 0, 0) + await self.sim_step(sim, sched, 1, 1) + await self.sim_step(sim, sched, 1, 1) + await self.sim_step(sim, sched, 0, 0) with self.run_simulation(sched) as sim: - sim.add_process(process) + sim.add_testbench(process) def test_multi(self): sched = Scheduler(4) self.count_test(sched, 4) - def process(): - yield from self.sim_step(sched, 0b0000, 0b0000) - yield from self.sim_step(sched, 0b1010, 0b0010) - yield from self.sim_step(sched, 0b1010, 0b1000) - yield from self.sim_step(sched, 0b1010, 0b0010) - yield from self.sim_step(sched, 0b1001, 0b1000) - yield from self.sim_step(sched, 0b1001, 0b0001) + async def process(sim): + await self.sim_step(sim, sched, 0b0000, 0b0000) + await self.sim_step(sim, sched, 0b1010, 0b0010) + await self.sim_step(sim, sched, 0b1010, 0b1000) + await self.sim_step(sim, sched, 0b1010, 0b0010) + await self.sim_step(sim, sched, 0b1001, 0b1000) + await self.sim_step(sim, sched, 0b1001, 0b0001) - yield from self.sim_step(sched, 0b1111, 0b0010) - yield from self.sim_step(sched, 0b1111, 0b0100) - yield from self.sim_step(sched, 0b1111, 0b1000) - yield from self.sim_step(sched, 0b1111, 0b0001) + await self.sim_step(sim, sched, 0b1111, 0b0010) + await self.sim_step(sim, sched, 0b1111, 0b0100) + await self.sim_step(sim, sched, 0b1111, 0b1000) + await self.sim_step(sim, sched, 0b1111, 0b0001) - yield from self.sim_step(sched, 0b0000, 0b0000) - yield from self.sim_step(sched, 0b0010, 0b0010) - yield from self.sim_step(sched, 0b0010, 0b0010) + await self.sim_step(sim, sched, 0b0000, 0b0000) + await self.sim_step(sim, sched, 0b0010, 0b0010) + await self.sim_step(sim, sched, 0b0010, 0b0010) with self.run_simulation(sched) as sim: - sim.add_process(process) + sim.add_testbench(process) class TransactionConflictTestCircuit(Elaboratable): @@ -132,14 +133,19 @@ def setup_method(self): random.seed(42) def make_process( - self, io: TestbenchIO, prob: float, src: Iterable[int], tgt: Callable[[int], None], chk: Callable[[int], None] + self, + io: TestbenchIO, + prob: float, + src: Iterable[int], + tgt: Callable[[int], None], + chk: Callable[[int], None], ): - def process(): + async def process(sim): for i in src: while random.random() >= prob: - yield Tick() + await sim.tick() tgt(i) - r = yield from io.call(data=i) + r = await io.call(sim, data=i) chk(r["data"]) return process @@ -193,9 +199,9 @@ def test_calls(self, name, prob1, prob2, probout): self.m = TransactionConflictTestCircuit(self.__class__.scheduler) with self.run_simulation(self.m, add_transaction_module=False) as sim: - sim.add_process(self.make_in1_process(prob1)) - sim.add_process(self.make_in2_process(prob2)) - sim.add_process(self.make_out_process(probout)) + sim.add_testbench(self.make_in1_process(prob1)) + sim.add_testbench(self.make_in2_process(prob2)) + sim.add_testbench(self.make_out_process(probout)) assert not self.in_expected assert not self.out1_expected @@ -210,7 +216,7 @@ def __init__(self): self.t2 = Signal() @abstractmethod - def elaborate(self, platform) -> TModule: + def elaborate(self, platform) -> HasElaborate: raise NotImplementedError @@ -283,22 +289,22 @@ def setup_method(self): def test_priorities(self, priority: Priority): m = self.circuit(priority) - def process(): + async def process(sim): to_do = 5 * [(0, 1), (1, 0), (1, 1)] random.shuffle(to_do) for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Settle() - assert (yield m.t1) != (yield m.t2) + sim.set(m.r1, r1) + sim.set(m.r2, r2) + _, t1, t2 = await sim.delay(1e-9).sample(m.t1, m.t2) + assert t1 != t2 if r1 == 1 and r2 == 1: if priority == Priority.LEFT: - assert (yield m.t1) + assert t1 if priority == Priority.RIGHT: - assert (yield m.t2) + assert t2 with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) def test_unsatisfiable(self, priority: Priority): @@ -368,18 +374,18 @@ def setup_method(self): def test_scheduling(self): m = self.circuit() - def process(): + async def process(sim): to_do = 5 * [(0, 1), (1, 0), (1, 1)] random.shuffle(to_do) for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Tick() - assert (yield m.t1) == r1 - assert (yield m.t2) == r1 * r2 + sim.set(m.r1, r1) + sim.set(m.r2, r2) + *_, t1, t2 = await sim.tick().sample(m.t1, m.t2) + assert t1 == r1 + assert t2 == r1 * r2 with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) class ScheduleBeforeTestCircuit(SchedulingTestCircuit): @@ -414,18 +420,18 @@ def setup_method(self): def test_schedule_before(self): m = ScheduleBeforeTestCircuit() - def process(): + async def process(sim): to_do = 5 * [(0, 1), (1, 0), (1, 1)] random.shuffle(to_do) for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Tick() - assert (yield m.t1) == r1 - assert not (yield m.t2) + sim.set(m.r1, r1) + sim.set(m.r2, r2) + *_, t1, t2 = await sim.tick().sample(m.t1, m.t2) + assert t1 == r1 + assert not t2 with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) class SingleCallerTestCircuit(Elaboratable): diff --git a/test/transactron/lib/test_fifo.py b/test/transactron/lib/test_fifo.py index 39de8929a..b9d0c5745 100644 --- a/test/transactron/lib/test_fifo.py +++ b/test/transactron/lib/test_fifo.py @@ -1,9 +1,8 @@ from amaranth import * -from amaranth.sim import Settle, Tick from transactron.lib import AdapterTrans, BasicFifo -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout, TestbenchContext from collections import deque from parameterized import parameterized_class import random @@ -44,36 +43,30 @@ def test_randomized(self): self.done = False - def source(): + async def source(sim: TestbenchContext): for _ in range(cycles): - if random.randint(0, 1): - yield # random delay + await self.random_wait_geom(sim, 0.5) v = random.randint(0, (2**fifoc.fifo.width) - 1) - yield from fifoc.fifo_write.call(data=v) expq.appendleft(v) + await fifoc.fifo_write.call(sim, data=v) if random.random() < 0.005: - yield from fifoc.fifo_clear.call() - yield Settle() + await fifoc.fifo_clear.call(sim) + await sim.delay(1e-9) expq.clear() self.done = True - def target(): + async def target(sim: TestbenchContext): while not self.done or expq: - if random.randint(0, 1): - yield Tick() + await self.random_wait_geom(sim, 0.5) - yield from fifoc.fifo_read.call_init() - yield Tick() + v = await fifoc.fifo_read.call_try(sim) - v = yield from fifoc.fifo_read.call_result() if v is not None: - assert v["data"] == expq.pop() - - yield from fifoc.fifo_read.disable() + assert v.data == expq.pop() with self.run_simulation(fifoc) as sim: - sim.add_process(source) - sim.add_process(target) + sim.add_testbench(source) + sim.add_testbench(target) diff --git a/test/transactron/lib/test_transaction_lib.py b/test/transactron/lib/test_transaction_lib.py index 217897347..6932e4985 100644 --- a/test/transactron/lib/test_transaction_lib.py +++ b/test/transactron/lib/test_transaction_lib.py @@ -3,7 +3,6 @@ import random from operator import and_ from functools import reduce -from amaranth.sim import Settle, Tick from typing import Optional, TypeAlias from parameterized import parameterized from collections import deque @@ -11,14 +10,17 @@ from amaranth import * from transactron import * from transactron.lib import * +from transactron.testing.method_mock import MethodMock +from transactron.testing.testbenchio import CallTrigger from transactron.utils._typing import ModuleLike, MethodStruct, RecordDict from transactron.utils import ModuleConnector from transactron.testing import ( SimpleTestCircuit, TestCaseWithSimulator, - TestbenchIO, data_layout, def_method_mock, + TestbenchIO, + TestbenchContext, ) @@ -45,19 +47,19 @@ def do_test_fifo( random.seed(1337) - def writer(): + async def writer(sim: TestbenchContext): for i in range(2**iosize): - yield from m.write.call(data=i) - yield from self.random_wait(writer_rand) + await m.write.call(sim, data=i) + await self.random_wait(sim, writer_rand) - def reader(): + async def reader(sim: TestbenchContext): for i in range(2**iosize): - assert (yield from m.read.call()) == {"data": i} - yield from self.random_wait(reader_rand) + assert (await m.read.call(sim)).data == i + await self.random_wait(sim, reader_rand) with self.run_simulation(m) as sim: - sim.add_process(reader) - sim.add_process(writer) + sim.add_testbench(reader) + sim.add_testbench(writer) class TestFIFO(TestFifoBase): @@ -86,46 +88,35 @@ def test_forwarding(self): m = SimpleTestCircuit(Forwarder(data_layout(iosize))) - def forward_check(x): - yield from m.read.call_init() - yield from m.write.call_init(data=x) - yield Settle() - assert (yield from m.read.call_result()) == {"data": x} - assert (yield from m.write.call_result()) is not None - yield Tick() + async def forward_check(sim: TestbenchContext, x: int): + read_res, write_res = await CallTrigger(sim).call(m.read).call(m.write, data=x) + assert read_res is not None and read_res.data == x + assert write_res is not None - def process(): + async def process(sim: TestbenchContext): # test forwarding behavior for x in range(4): - yield from forward_check(x) + await forward_check(sim, x) # load the overflow buffer - yield from m.read.disable() - yield from m.write.call_init(data=42) - yield Settle() - assert (yield from m.write.call_result()) is not None - yield Tick() + res = await m.write.call_try(sim, data=42) + assert res is not None # writes are not possible now - yield from m.write.call_init(data=84) - yield Settle() - assert (yield from m.write.call_result()) is None - yield Tick() + res = await m.write.call_try(sim, data=42) + assert res is None # read from the overflow buffer, writes still blocked - yield from m.read.enable() - yield from m.write.call_init(data=111) - yield Settle() - assert (yield from m.read.call_result()) == {"data": 42} - assert (yield from m.write.call_result()) is None - yield Tick() + read_res, write_res = await CallTrigger(sim).call(m.read).call(m.write, data=111) + assert read_res is not None and read_res.data == 42 + assert write_res is None # forwarding now works again for x in range(4): - yield from forward_check(x) + await forward_check(sim, x) with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) class TestPipe(TestFifoBase): @@ -165,7 +156,7 @@ def test_mem( transparent=transparent, read_ports=read_ports, write_ports=write_ports, - ) + ), ) data: list[int] = [0 for _ in range(max_addr)] @@ -174,43 +165,39 @@ def test_mem( random.seed(seed) def writer(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): d = random.randrange(2**data_width) a = random.randrange(max_addr) - yield from m.writes[i].call(data=d, addr=a) - for _ in range(i + 2 if not transparent else i): - yield Settle() + await m.writes[i].call(sim, data={"data": d}, addr=a) + await sim.delay(1e-9 * (i + 2 if not transparent else i)) data[a] = d - yield from self.random_wait(writer_rand) + await self.random_wait(sim, writer_rand) return process def reader_req(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): a = random.randrange(max_addr) - yield from m.read_reqs[i].call(addr=a) - for _ in range(1 if not transparent else write_ports + 2): - yield Settle() + await m.read_reqs[i].call(sim, addr=a) + await sim.delay(1e-9 * (1 if not transparent else write_ports + 2)) d = data[a] read_req_queues[i].append(d) - yield from self.random_wait(reader_req_rand) + await self.random_wait(sim, reader_req_rand) return process def reader_resp(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): - for _ in range(write_ports + 3): - yield Settle() + await sim.delay(1e-9 * (write_ports + 3)) while not read_req_queues[i]: - yield from self.random_wait(reader_resp_rand or 1, min_cycle_cnt=1) - for _ in range(write_ports + 3): - yield Settle() + await self.random_wait(sim, reader_resp_rand or 1, min_cycle_cnt=1) + await sim.delay(1e-9 * (write_ports + 3)) d = read_req_queues[i].popleft() - assert (yield from m.read_resps[i].call()) == {"data": d} - yield from self.random_wait(reader_resp_rand) + assert (await m.read_resps[i].call(sim)).data == d + await self.random_wait(sim, reader_resp_rand) return process @@ -219,10 +206,10 @@ def process(): with self.run_simulation(m, max_cycles=max_cycles) as sim: for i in range(read_ports): - sim.add_process(reader_req(i)) - sim.add_process(reader_resp(i)) + sim.add_testbench(reader_req(i)) + sim.add_testbench(reader_resp(i)) for i in range(write_ports): - sim.add_process(writer(i)) + sim.add_testbench(writer(i)) class TestAsyncMemoryBank(TestCaseWithSimulator): @@ -238,7 +225,7 @@ def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, m = SimpleTestCircuit( AsyncMemoryBank( data_layout=[("data", data_width)], elem_count=max_addr, read_ports=read_ports, write_ports=write_ports - ) + ), ) data: list[int] = list(0 for i in range(max_addr)) @@ -246,43 +233,41 @@ def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, random.seed(seed) def writer(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): d = random.randrange(2**data_width) a = random.randrange(max_addr) - yield from m.writes[i].call(data=d, addr=a) - for _ in range(i + 2): - yield Settle() + await m.writes[i].call(sim, data={"data": d}, addr=a) + await sim.delay(1e-9 * (i + 2)) data[a] = d - yield from self.random_wait(writer_rand, min_cycle_cnt=1) + await self.random_wait(sim, writer_rand, min_cycle_cnt=1) return process def reader(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): a = random.randrange(max_addr) - d = yield from m.reads[i].call(addr=a) - for _ in range(1): - yield Settle() + d = await m.reads[i].call(sim, addr=a) + await sim.delay(1e-9) expected_d = data[a] assert d["data"] == expected_d - yield from self.random_wait(reader_rand, min_cycle_cnt=1) + await self.random_wait(sim, reader_rand, min_cycle_cnt=1) return process with self.run_simulation(m) as sim: for i in range(read_ports): - sim.add_process(reader(i)) + sim.add_testbench(reader(i)) for i in range(write_ports): - sim.add_process(writer(i)) + sim.add_testbench(writer(i)) class ManyToOneConnectTransTestCircuit(Elaboratable): def __init__(self, count: int, lay: MethodLayout): self.count = count self.lay = lay - self.inputs = [] + self.inputs: list[TestbenchIO] = [] def elaborate(self, platform): m = TModule() @@ -296,8 +281,7 @@ def elaborate(self, platform): # Create ManyToOneConnectTrans, which will serialize results from different inputs output = TestbenchIO(Adapter(i=self.lay)) - m.submodules.output = output - self.output = output + m.submodules.output = self.output = output m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface) return m @@ -341,19 +325,19 @@ def generate_producer(self, i: int): results to its output FIFO. This records will be next serialized by FUArbiter. """ - def producer(): + async def producer(sim: TestbenchContext): inputs = self.inputs[i] for field1, field2 in inputs: - io: TestbenchIO = self.m.inputs[i] - yield from io.call_init(field1=field1, field2=field2) - yield from self.random_wait(self.max_wait) + self.m.inputs[i].call_init(sim, field1=field1, field2=field2) + await self.random_wait(sim, self.max_wait) self.producer_end[i] = True return producer - def consumer(self): + async def consumer(self, sim: TestbenchContext): + # TODO: this test doesn't test anything, needs to be fixed! while reduce(and_, self.producer_end, True): - result = yield from self.m.output.call_do() + result = await self.m.output.call_do(sim) assert result is not None @@ -363,23 +347,16 @@ def consumer(self): del self.expected_output[t] else: self.expected_output[t] -= 1 - yield from self.random_wait(self.max_wait) - - def test_one_out(self): - self.count = 1 - self.initialize() - with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) - for i in range(self.count): - sim.add_process(self.generate_producer(i)) + await self.random_wait(sim, self.max_wait) - def test_many_out(self): - self.count = 4 + @pytest.mark.parametrize("count", [1, 4]) + def test(self, count: int): + self.count = count self.initialize() with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) + sim.add_testbench(self.consumer) for i in range(self.count): - sim.add_process(self.generate_producer(i)) + sim.add_testbench(self.generate_producer(i)) class MethodMapTestCircuit(Elaboratable): @@ -446,11 +423,11 @@ def _(arg: MethodStruct): class TestMethodTransformer(TestCaseWithSimulator): m: MethodMapTestCircuit - def source(self): + async def source(self, sim: TestbenchContext): for i in range(2**self.m.iosize): - v = yield from self.m.source.call(data=i) + v = await self.m.source.call(sim, data=i) i1 = (i + 1) & ((1 << self.m.iosize) - 1) - assert v["data"] == (((i1 << 1) | (i1 >> (self.m.iosize - 1))) - 1) & ((1 << self.m.iosize) - 1) + assert v.data == (((i1 << 1) | (i1 >> (self.m.iosize - 1))) - 1) & ((1 << self.m.iosize) - 1) @def_method_mock(lambda self: self.m.target) def target(self, data): @@ -459,18 +436,17 @@ def target(self, data): def test_method_transformer(self): self.m = MethodMapTestCircuit(4, False, False) with self.run_simulation(self.m) as sim: - sim.add_process(self.source) - sim.add_process(self.target) + sim.add_testbench(self.source) def test_method_transformer_dicts(self): self.m = MethodMapTestCircuit(4, False, True) with self.run_simulation(self.m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) def test_method_transformer_with_methods(self): self.m = MethodMapTestCircuit(4, True, True) with self.run_simulation(self.m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) class TestMethodFilter(TestCaseWithSimulator): @@ -480,34 +456,31 @@ def initialize(self): self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout)) self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) - def source(self): + async def source(self, sim: TestbenchContext): for i in range(2**self.iosize): - v = yield from self.tc.method.call(data=i) + v = await self.tc.method.call(sim, data=i) if i & 1: - assert v["data"] == (i + 1) & ((1 << self.iosize) - 1) + assert v.data == (i + 1) & ((1 << self.iosize) - 1) else: - assert v["data"] == 0 + assert v.data == 0 - @def_method_mock(lambda self: self.target, sched_prio=2) + @def_method_mock(lambda self: self.target) def target_mock(self, data): return {"data": data + 1} - @def_method_mock(lambda self: self.cmeth, sched_prio=1) + @def_method_mock(lambda self: self.cmeth) def cmeth_mock(self, data): return {"data": data % 2} - @parameterized.expand([(True,), (False,)]) - def test_method_filter_with_methods(self, use_condition): + def test_method_filter_with_methods(self): self.initialize() - self.tc = SimpleTestCircuit( - MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface, use_condition=use_condition) - ) + self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface)) m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) with self.run_simulation(m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) @parameterized.expand([(True,), (False,)]) - def test_method_filter(self, use_condition): + def test_method_filter_plain(self, use_condition): self.initialize() def condition(_, v): @@ -516,7 +489,7 @@ def condition(_, v): self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition, use_condition=use_condition)) m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) with self.run_simulation(m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) class MethodProductTestCircuit(Elaboratable): @@ -562,36 +535,36 @@ def test_method_product(self, targets: int, add_combiner: bool): def target_process(k: int): @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def process(data): + def mock(data): return {"data": data + k} - return process + return mock() - def method_process(): + async def method_process(sim: TestbenchContext): # if any of the target methods is not enabled, call does not succeed for i in range(2**targets - 1): for k in range(targets): method_en[k] = bool(i & (1 << k)) - yield Tick() - assert (yield from m.method.call_try(data=0)) is None + await sim.tick() + assert (await m.method.call_try(sim, data=0)) is None # otherwise, the call succeeds for k in range(targets): method_en[k] = True - yield Tick() + await sim.tick() data = random.randint(0, (1 << iosize) - 1) - val = (yield from m.method.call(data=data))["data"] + val = (await m.method.call(sim, data=data)).data if add_combiner: assert val == (targets * data + (targets - 1) * targets // 2) & ((1 << iosize) - 1) else: assert val == data with self.run_simulation(m) as sim: - sim.add_process(method_process) + sim.add_testbench(method_process) for k in range(targets): - sim.add_process(target_process(k)) + self.add_mock(sim, target_process(k)) class TestSerializer(TestCaseWithSimulator): @@ -613,7 +586,7 @@ def setup_method(self): port_count=self.port_count, serialized_req_method=self.req_method.adapter.iface, serialized_resp_method=self.resp_method.adapter.iface, - ) + ), ) self.m = ModuleConnector( test_circuit=self.test_circuit, req_method=self.req_method, resp_method=self.resp_method @@ -628,38 +601,44 @@ def setup_method(self): @def_method_mock(lambda self: self.req_method, enable=lambda self: not self.got_request) def serial_req_mock(self, field): - self.serialized_data.append(field) - self.got_request = True + @MethodMock.effect + def eff(): + self.serialized_data.append(field) + self.got_request = True @def_method_mock(lambda self: self.resp_method, enable=lambda self: self.got_request) def serial_resp_mock(self): - self.got_request = False - return {"field": self.serialized_data[-1]} + @MethodMock.effect + def eff(): + self.got_request = False + + if self.serialized_data: + return {"field": self.serialized_data[-1]} def requestor(self, i: int): - def f(): + async def f(sim: TestbenchContext): for _ in range(self.test_count): d = random.randrange(2**self.data_width) - yield from self.test_circuit.serialize_in[i].call(field=d) + await self.test_circuit.serialize_in[i].call(sim, field=d) self.port_data[i].append(d) - yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) + await self.random_wait(sim, self.requestor_rand, min_cycle_cnt=1) return f def responder(self, i: int): - def f(): + async def f(sim: TestbenchContext): for _ in range(self.test_count): - data_out = yield from self.test_circuit.serialize_out[i].call() - assert self.port_data[i].popleft() == data_out["field"] - yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) + data_out = await self.test_circuit.serialize_out[i].call(sim) + assert self.port_data[i].popleft() == data_out.field + await self.random_wait(sim, self.requestor_rand, min_cycle_cnt=1) return f def test_serial(self): with self.run_simulation(self.m) as sim: for i in range(self.port_count): - sim.add_process(self.requestor(i)) - sim.add_process(self.responder(i)) + sim.add_testbench(self.requestor(i)) + sim.add_testbench(self.responder(i)) class TestMethodTryProduct(TestCaseWithSimulator): @@ -674,32 +653,32 @@ def test_method_try_product(self, targets: int, add_combiner: bool): def target_process(k: int): @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def process(data): + def mock(data): return {"data": data + k} - return process + return mock() - def method_process(): + async def method_process(sim: TestbenchContext): for i in range(2**targets): for k in range(targets): method_en[k] = bool(i & (1 << k)) active_targets = sum(method_en) - yield Tick() + await sim.tick() data = random.randint(0, (1 << iosize) - 1) - val = yield from m.method.call(data=data) + val = await m.method.call(sim, data=data) if add_combiner: adds = sum(k * method_en[k] for k in range(targets)) - assert val == {"data": (active_targets * data + adds) & ((1 << iosize) - 1)} + assert val.data == (active_targets * data + adds) & ((1 << iosize) - 1) else: - assert val == {} + assert val.shape().size == 0 with self.run_simulation(m) as sim: - sim.add_process(method_process) + sim.add_testbench(method_process) for k in range(targets): - sim.add_process(target_process(k)) + self.add_mock(sim, target_process(k)) class MethodTryProductTestCircuit(Elaboratable): @@ -768,7 +747,7 @@ def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): target = TestbenchIO(Adapter(i=[("cond", 2)])) circ = SimpleTestCircuit( - ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall) + ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall), ) m = ModuleConnector(test_circuit=circ, target=target) @@ -776,14 +755,17 @@ def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): @def_method_mock(lambda: target) def target_process(cond): - nonlocal selection - selection = cond + @MethodMock.effect + def eff(): + nonlocal selection + selection = cond - def process(): + async def process(sim: TestbenchContext): nonlocal selection + await sim.tick() # TODO workaround for mocks inactive in first cycle for c1, c2, c3 in product([0, 1], [0, 1], [0, 1]): selection = None - res = yield from circ.source.call_try(cond1=c1, cond2=c2, cond3=c3) + res = await circ.source.call_try(sim, cond1=c1, cond2=c2, cond3=c3) if catchall or nonblocking: assert res is not None @@ -801,4 +783,4 @@ def process(): assert selection in [c1, 2 * c2, 3 * c3] with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/transactron/test_adapter.py b/test/transactron/test_adapter.py index a5fa73264..93d0611ae 100644 --- a/test/transactron/test_adapter.py +++ b/test/transactron/test_adapter.py @@ -2,8 +2,8 @@ from transactron import Method, def_method, TModule - -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, ModuleConnector +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, TestbenchContext +from transactron.utils.amaranth_ext.elaboratables import ModuleConnector class Echo(Elaboratable): @@ -45,12 +45,11 @@ def _(arg): class TestAdapterTrans(TestCaseWithSimulator): - def proc(self): + async def proc(self, sim: TestbenchContext): for _ in range(3): - # this would previously timeout if the output layout was empty (as is in this case) - yield from self.consumer.action.call(data=0) + await self.consumer.action.call(sim, data=0) for expected in [4, 1, 0]: - obtained = (yield from self.echo.action.call(data=expected))["data"] + obtained = (await self.echo.action.call(sim, data=expected)).data assert expected == obtained def test_single(self): @@ -59,4 +58,4 @@ def test_single(self): self.m = ModuleConnector(echo=self.echo, consumer=self.consumer) with self.run_simulation(self.m, max_cycles=100) as sim: - sim.add_process(self.proc) + sim.add_testbench(self.proc) diff --git a/test/transactron/test_connectors.py b/test/transactron/test_connectors.py index ac15a9f9d..e147a2fb6 100644 --- a/test/transactron/test_connectors.py +++ b/test/transactron/test_connectors.py @@ -1,10 +1,8 @@ import random from parameterized import parameterized_class -from amaranth.sim import Settle, Tick - from transactron.lib import StableSelectingNetwork -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext @parameterized_class( @@ -19,7 +17,7 @@ def test(self): random.seed(42) - def process(): + async def process(sim: TestbenchContext): for _ in range(100): inputs = [random.randrange(2**8) for _ in range(self.n)] valids = [random.randrange(2) for _ in range(self.n)] @@ -27,20 +25,18 @@ def process(): expected_output_prefix = [] for i in range(self.n): - yield m.valids[i].eq(valids[i]) - yield m.inputs[i].data.eq(inputs[i]) + sim.set(m.valids[i], valids[i]) + sim.set(m.inputs[i].data, inputs[i]) if valids[i]: expected_output_prefix.append(inputs[i]) - yield Settle() - for i in range(total): - out = yield m.outputs[i].data + out = sim.get(m.outputs[i].data) assert out == expected_output_prefix[i] - assert (yield m.output_cnt) == total - yield Tick() + assert sim.get(m.output_cnt) == total + await sim.tick() with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/transactron/test_methods.py b/test/transactron/test_methods.py index e03ae5f17..e4a5ced78 100644 --- a/test/transactron/test_methods.py +++ b/test/transactron/test_methods.py @@ -154,14 +154,14 @@ def definition(idx: int, foo: Value): circuit = SimpleTestCircuit(TestDefMethods.CircuitTestModule(definition)) - def test_process(): + async def test_process(sim): for k, method in enumerate(circuit.methods): val = random.randrange(0, 2**3) - ret = yield from method.call(foo=val) + ret = await method.call(sim, foo=val) assert ret["foo"] == (val + k) % 2**3 with self.run_simulation(circuit) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) class AdapterCircuit(Elaboratable): @@ -392,13 +392,13 @@ class TestQuadrupleCircuits(TestCaseWithSimulator): def test(self, quadruple): circ = QuadrupleCircuit(quadruple()) - def process(): + async def process(sim): for n in range(1 << (WIDTH - 2)): - out = yield from circ.tb.call(data=n) + out = await circ.tb.call(sim, data=n) assert out["data"] == n * 4 with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class ConditionalCallCircuit(Elaboratable): @@ -483,31 +483,27 @@ class TestConditionals(TestCaseWithSimulator): def test_conditional_call(self): circ = ConditionalCallCircuit() - def process(): - yield from circ.out.disable() - yield from circ.tb.call_init(data=0) - yield Settle() - assert not (yield from circ.out.done()) - assert not (yield from circ.tb.done()) + async def process(sim): + circ.out.disable(sim) + circ.tb.call_init(sim, data=0) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert not out_done and not tb_done - yield from circ.out.enable() - yield Settle() - assert not (yield from circ.out.done()) - assert (yield from circ.tb.done()) + circ.out.enable(sim) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert not out_done and tb_done - yield from circ.tb.call_init(data=1) - yield Settle() - assert (yield from circ.out.done()) - assert (yield from circ.tb.done()) + circ.tb.call_init(sim, data=1) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert out_done and tb_done # the argument is still 1 but the method is not called - yield from circ.tb.disable() - yield Settle() - assert not (yield from circ.out.done()) - assert not (yield from circ.tb.done()) + circ.tb.disable(sim) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert not out_done and not tb_done with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) @parameterized.expand( [ @@ -520,18 +516,18 @@ def process(): def test_conditional(self, elaboratable): circ = elaboratable() - def process(): - yield from circ.tb.enable() - yield circ.ready.eq(0) - yield Settle() - assert not (yield from circ.tb.done()) + async def process(sim): + circ.tb.enable(sim) + sim.set(circ.ready, 0) + *_, tb_done = await sim.tick().sample(circ.tb.adapter.done) + assert not tb_done - yield circ.ready.eq(1) - yield Settle() - assert (yield from circ.tb.done()) + sim.set(circ.ready, 1) + *_, tb_done = await sim.tick().sample(circ.tb.adapter.done) + assert tb_done with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class NonexclusiveMethodCircuit(Elaboratable): @@ -559,42 +555,33 @@ class TestNonexclusiveMethod(TestCaseWithSimulator): def test_nonexclusive_method(self): circ = NonexclusiveMethodCircuit() - def process(): + async def process(sim): for x in range(8): t1en = bool(x & 1) t2en = bool(x & 2) mrdy = bool(x & 4) - if t1en: - yield from circ.t1.enable() - else: - yield from circ.t1.disable() + circ.t1.set_enable(sim, t1en) + circ.t2.set_enable(sim, t2en) + sim.set(circ.ready, int(mrdy)) + sim.set(circ.data, x) - if t2en: - yield from circ.t2.enable() - else: - yield from circ.t2.disable() - - if mrdy: - yield circ.ready.eq(1) - else: - yield circ.ready.eq(0) - - yield circ.data.eq(x) - yield Settle() + *_, running, t1_done, t2_done, t1_outputs, t2_outputs = await sim.delay(1e-9).sample( + circ.running, circ.t1.done, circ.t2.done, circ.t1.outputs, circ.t2.outputs + ) - assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) - assert bool((yield from circ.t1.done())) == (t1en and mrdy) - assert bool((yield from circ.t2.done())) == (t2en and mrdy) + assert bool(running) == ((t1en or t2en) and mrdy) + assert bool(t1_done) == (t1en and mrdy) + assert bool(t2_done) == (t2en and mrdy) if t1en and mrdy: - assert (yield from circ.t1.get_outputs()) == {"data": x} + assert t1_outputs["data"] == x if t2en and mrdy: - assert (yield from circ.t2.get_outputs()) == {"data": x} + assert t2_outputs["data"] == x with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class TwoNonexclusiveConflictCircuit(Elaboratable): @@ -638,15 +625,15 @@ class TestConflicting(TestCaseWithSimulator): def test_conflicting(self, test_circuit: Callable[[], TwoNonexclusiveConflictCircuit]): circ = test_circuit() - def process(): - yield from circ.t1.enable() - yield from circ.t2.enable() - yield Settle() + async def process(sim): + circ.t1.enable(sim) + circ.t2.enable(sim) + *_, running1, running2 = await sim.delay(1e-9).sample(circ.running1, circ.running2) - assert not (yield circ.running1) or not (yield circ.running2) + assert not running1 or not running2 with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class CustomCombinerMethodCircuit(Elaboratable): @@ -679,7 +666,7 @@ class TestCustomCombinerMethod(TestCaseWithSimulator): def test_custom_combiner_method(self): circ = CustomCombinerMethodCircuit() - def process(): + async def process(sim): for x in range(8): t1en = bool(x & 1) t2en = bool(x & 2) @@ -690,38 +677,30 @@ def process(): val1e = val1 if t1en else 0 val2e = val2 if t2en else 0 - yield from circ.t1.call_init(data=val1) - yield from circ.t2.call_init(data=val2) + circ.t1.call_init(sim, data=val1) + circ.t2.call_init(sim, data=val2) - if t1en: - yield from circ.t1.enable() - else: - yield from circ.t1.disable() + circ.t1.set_enable(sim, t1en) + circ.t2.set_enable(sim, t2en) - if t2en: - yield from circ.t2.enable() - else: - yield from circ.t2.disable() + sim.set(circ.ready, int(mrdy)) - if mrdy: - yield circ.ready.eq(1) - else: - yield circ.ready.eq(0) - - yield Settle() + *_, running, t1_done, t2_done, t1_outputs, t2_outputs = await sim.delay(1e-9).sample( + circ.running, circ.t1.done, circ.t2.done, circ.t1.outputs, circ.t2.outputs + ) - assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) - assert bool((yield from circ.t1.done())) == (t1en and mrdy) - assert bool((yield from circ.t2.done())) == (t2en and mrdy) + assert bool(running) == ((t1en or t2en) and mrdy) + assert bool(t1_done) == (t1en and mrdy) + assert bool(t2_done) == (t2en and mrdy) if t1en and mrdy: - assert (yield from circ.t1.get_outputs()) == {"data": val1e ^ val2e} + assert t1_outputs["data"] == val1e ^ val2e if t2en and mrdy: - assert (yield from circ.t2.get_outputs()) == {"data": val1e ^ val2e} + assert t2_outputs["data"] == val1e ^ val2e with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class DataDependentConditionalCircuit(Elaboratable): @@ -729,8 +708,8 @@ def __init__(self, n=2, ready_function=lambda arg: arg.data != 3): self.method = Method(i=data_layout(n)) self.ready_function = ready_function - self.in_t1 = Signal(from_method_layout(data_layout(n))) - self.in_t2 = Signal(from_method_layout(data_layout(n))) + self.in_t1 = Signal(n) + self.in_t2 = Signal(n) self.ready = Signal() self.req_t1 = Signal() self.req_t2 = Signal() @@ -748,11 +727,11 @@ def _(data): with Transaction().body(m, request=self.req_t1): m.d.comb += self.out_t1.eq(1) - self.method(m, self.in_t1) + self.method(m, data=self.in_t1) with Transaction().body(m, request=self.req_t2): m.d.comb += self.out_t2.eq(1) - self.method(m, self.in_t2) + self.method(m, data=self.in_t2) return m @@ -767,7 +746,7 @@ def base_random(self, f): random.seed(14) self.circ = DataDependentConditionalCircuit(n=self.n, ready_function=f) - def process(): + async def process(sim): for _ in range(self.test_number): in1 = random.randrange(0, 2**self.n) in2 = random.randrange(0, 2**self.n) @@ -775,16 +754,15 @@ def process(): req_t1 = random.randrange(2) req_t2 = random.randrange(2) - yield self.circ.in_t1.eq(in1) - yield self.circ.in_t2.eq(in2) - yield self.circ.req_t1.eq(req_t1) - yield self.circ.req_t2.eq(req_t2) - yield self.circ.ready.eq(m_ready) - yield Settle() + sim.set(self.circ.in_t1, in1) + sim.set(self.circ.in_t2, in2) + sim.set(self.circ.req_t1, req_t1) + sim.set(self.circ.req_t2, req_t2) + sim.set(self.circ.ready, m_ready) - out_m = yield self.circ.out_m - out_t1 = yield self.circ.out_t1 - out_t2 = yield self.circ.out_t2 + *_, out_m, out_t1, out_t2 = await sim.delay(1e-9).sample( + self.circ.out_m, self.circ.out_t1, self.circ.out_t2 + ) if not m_ready or (not req_t1 or in1 == self.bad_number) and (not req_t2 or in2 == self.bad_number): assert out_m == 0 @@ -800,10 +778,10 @@ def process(): assert in1 != self.bad_number or not out_t1 assert in2 != self.bad_number or not out_t2 - yield Tick() + await sim.tick() with self.run_simulation(self.circ, 100) as sim: - sim.add_process(process) + sim.add_testbench(process) def test_random_arg(self): self.base_random(lambda arg: arg.data != self.bad_number) diff --git a/test/transactron/test_metrics.py b/test/transactron/test_metrics.py index c7a0f0765..181dcc839 100644 --- a/test/transactron/test_metrics.py +++ b/test/transactron/test_metrics.py @@ -7,12 +7,11 @@ from parameterized import parameterized_class from amaranth import * -from amaranth.sim import Settle, Tick from transactron.lib.metrics import * from transactron import * -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit -from transactron.testing.infrastructure import Now +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, TestbenchContext +from transactron.testing.tick_count import TicksKey from transactron.utils.dependencies import DependencyContext @@ -74,72 +73,62 @@ def test_counter_in_method(self): m = SimpleTestCircuit(CounterInMethodCircuit()) DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): called_cnt = 0 for _ in range(200): call_now = random.randint(0, 1) == 0 if call_now: - yield from m.method.call() + await m.method.call(sim) + called_cnt += 1 else: - yield Tick() - - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m._dut.counter.count.value) + await sim.tick() - if call_now: - called_cnt += 1 + assert called_cnt == sim.get(m._dut.counter.count.value) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) def test_counter_with_condition_in_method(self): m = SimpleTestCircuit(CounterWithConditionInMethodCircuit()) DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): called_cnt = 0 for _ in range(200): call_now = random.randint(0, 1) == 0 condition = random.randint(0, 1) if call_now: - yield from m.method.call(cond=condition) + await m.method.call(sim, cond=condition) + called_cnt += condition else: - yield Tick() + await sim.tick() - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m._dut.counter.count.value) - - if call_now and condition == 1: - called_cnt += 1 + assert called_cnt == sim.get(m._dut.counter.count.value) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) def test_counter_with_condition_without_method(self): m = CounterWithoutMethodCircuit() DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): called_cnt = 0 for _ in range(200): condition = random.randint(0, 1) - yield m.cond.eq(condition) - yield Tick() - - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m.counter.count.value) + sim.set(m.cond, condition) + await sim.tick() if condition == 1: called_cnt += 1 + assert called_cnt == sim.get(m.counter.count.value) + with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) class OneHotEnum(IntFlag): @@ -184,23 +173,23 @@ def do_test_enum(self, tags: range | Type[Enum] | list[int], tag_values: list[in for i in tag_values: counts[i] = 0 - def test_process(): + async def test_process(sim): for _ in range(200): for i in tag_values: - assert counts[i] == (yield m.counter.counters[i].value) + assert counts[i] == sim.get(m.counter.counters[i].value) tag = random.choice(list(tag_values)) - yield m.cond.eq(1) - yield m.tag.eq(tag) - yield Tick() - yield m.cond.eq(0) - yield Tick() + sim.set(m.cond, 1) + sim.set(m.tag, tag) + await sim.tick() + sim.set(m.cond, 0) + await sim.tick() counts[tag] += 1 with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) def test_one_hot_enum(self): self.do_test_enum(OneHotEnum, [e.value for e in OneHotEnum]) @@ -261,8 +250,8 @@ def test_histogram(self): max_sample_value = 2**self.sample_width - 1 - def test_process(): - min = max_sample_value + 1 + async def test_process(sim): + min = max_sample_value max = 0 sum = 0 count = 0 @@ -282,46 +271,43 @@ def test_process(): if value < 2**i or i == self.bucket_count - 1: buckets[i] += 1 break - yield from m.method.call(data=value) - yield Tick() + await m.method.call(sim, data=value) else: - yield Tick() + await sim.tick() histogram = m._dut.histogram - # Skip the assertion if the min is still uninitialized - if min != max_sample_value + 1: - assert min == (yield histogram.min.value) - assert max == (yield histogram.max.value) - assert sum == (yield histogram.sum.value) - assert count == (yield histogram.count.value) + assert min == sim.get(histogram.min.value) + assert max == sim.get(histogram.max.value) + assert sum == sim.get(histogram.sum.value) + assert count == sim.get(histogram.count.value) total_count = 0 for i in range(self.bucket_count): - bucket_value = yield histogram.buckets[i].value + bucket_value = sim.get(histogram.buckets[i].value) total_count += bucket_value assert buckets[i] == bucket_value # Sanity check if all buckets sum up to the total count value - assert total_count == (yield histogram.count.value) + assert total_count == sim.get(histogram.count.value) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) class TestLatencyMeasurerBase(TestCaseWithSimulator): - def check_latencies(self, m: SimpleTestCircuit, latencies: list[int]): - assert min(latencies) == (yield m._dut.histogram.min.value) - assert max(latencies) == (yield m._dut.histogram.max.value) - assert sum(latencies) == (yield m._dut.histogram.sum.value) - assert len(latencies) == (yield m._dut.histogram.count.value) + def check_latencies(self, sim, m: SimpleTestCircuit, latencies: list[int]): + assert min(latencies) == sim.get(m._dut.histogram.min.value) + assert max(latencies) == sim.get(m._dut.histogram.max.value) + assert sum(latencies) == sim.get(m._dut.histogram.sum.value) + assert len(latencies) == sim.get(m._dut.histogram.count.value) for i in range(m._dut.histogram.bucket_count): bucket_start = 0 if i == 0 else 2 ** (i - 1) bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i count = sum(1 for x in latencies if bucket_start <= x < bucket_end) - assert count == (yield m._dut.histogram.buckets[i].value) + assert count == sim.get(m._dut.histogram.buckets[i].value) @parameterized_class( @@ -351,36 +337,33 @@ def test_latency_measurer(self): finish = False - def producer(): + async def producer(sim: TestbenchContext): nonlocal finish + ticks = DependencyContext.get().get_dependency(TicksKey()) for _ in range(200): - yield from m._start.call() + await m._start.call(sim) - # Make sure that the time is updated first. - yield Settle() - time = yield Now() - event_queue.put(time) - yield from self.random_wait_geom(0.8) + event_queue.put(sim.get(ticks)) + await self.random_wait_geom(sim, 0.8) finish = True - def consumer(): + async def consumer(sim: TestbenchContext): + ticks = DependencyContext.get().get_dependency(TicksKey()) + while not finish: - yield from m._stop.call() + await m._stop.call(sim) - # Make sure that the time is updated first. - yield Settle() - time = yield Now() - latencies.append(time - event_queue.get()) + latencies.append(sim.get(ticks) - event_queue.get()) - yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) + await self.random_wait_geom(sim, 1.0 / self.expected_consumer_wait) - self.check_latencies(m, latencies) + self.check_latencies(sim, m, latencies) with self.run_simulation(m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) @parameterized_class( @@ -412,54 +395,51 @@ def test_latency_measurer(self): finish = False - def producer(): + async def producer(sim): nonlocal finish + tick_count = DependencyContext.get().get_dependency(TicksKey()) + for _ in range(200): while not free_slots: - yield Tick() - continue - yield Settle() + await sim.tick() + await sim.delay(1e-9) slot_id = random.choice(free_slots) - yield from m._start.call(slot=slot_id) + await m._start.call(sim, slot=slot_id) - time = yield Now() - - events[slot_id] = time + events[slot_id] = sim.get(tick_count) free_slots.remove(slot_id) used_slots.append(slot_id) - yield from self.random_wait_geom(0.8) + await self.random_wait_geom(sim, 0.8) finish = True - def consumer(): + async def consumer(sim): + tick_count = DependencyContext.get().get_dependency(TicksKey()) + while not finish: while not used_slots: - yield Tick() - continue + await sim.tick() slot_id = random.choice(used_slots) - yield from m._stop.call(slot=slot_id) - - time = yield Now() + await m._stop.call(sim, slot=slot_id) - yield Settle() - yield Settle() + await sim.delay(2e-9) - latencies.append(time - events[slot_id]) + latencies.append(sim.get(tick_count) - events[slot_id]) used_slots.remove(slot_id) free_slots.append(slot_id) - yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) + await self.random_wait_geom(sim, 1.0 / self.expected_consumer_wait) - self.check_latencies(m, latencies) + self.check_latencies(sim, m, latencies) with self.run_simulation(m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) class MetricManagerTestCircuit(Elaboratable): @@ -527,21 +507,20 @@ def test_returned_reg_values(self): DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): counters = [0] * 3 for _ in range(200): rand = [random.randint(0, 1) for _ in range(3)] - yield from m.incr_counters.call(counter1=rand[0], counter2=rand[1], counter3=rand[2]) - yield Tick() + await m.incr_counters.call(sim, counter1=rand[0], counter2=rand[1], counter3=rand[2]) for i in range(3): if rand[i] == 1: counters[i] += 1 - assert counters[0] == (yield metrics_manager.get_register_value("foo.counter1", "count")) - assert counters[1] == (yield metrics_manager.get_register_value("bar.baz.counter2", "count")) - assert counters[2] == (yield metrics_manager.get_register_value("bar.baz.counter3", "count")) + assert counters[0] == sim.get(metrics_manager.get_register_value("foo.counter1", "count")) + assert counters[1] == sim.get(metrics_manager.get_register_value("bar.baz.counter2", "count")) + assert counters[2] == sim.get(metrics_manager.get_register_value("bar.baz.counter3", "count")) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) diff --git a/test/transactron/test_simultaneous.py b/test/transactron/test_simultaneous.py index d0859301d..ad492e330 100644 --- a/test/transactron/test_simultaneous.py +++ b/test/transactron/test_simultaneous.py @@ -3,10 +3,12 @@ from typing import Optional from amaranth import * from amaranth.sim import * +from transactron.testing.method_mock import MethodMock, def_method_mock +from transactron.testing.testbenchio import TestbenchIO from transactron.utils import ModuleConnector -from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock +from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchContext from transactron import * from transactron.lib import Adapter, Connect, ConnectTrans @@ -44,17 +46,17 @@ class TestSimultaneousDiamond(TestCaseWithSimulator): def test_diamond(self): circ = SimpleTestCircuit(SimultaneousDiamondTestCircuit()) - def process(): + async def process(sim: TestbenchContext): methods = {"l": circ.method_l, "r": circ.method_r, "u": circ.method_u, "d": circ.method_d} for i in range(1 << len(methods)): enables: dict[str, bool] = {} for k, n in enumerate(methods): enables[n] = bool(i & (1 << k)) - yield from methods[n].set_enable(enables[n]) - yield Tick() + methods[n].set_enable(sim, enables[n]) dones: dict[str, bool] = {} for n in methods: - dones[n] = bool((yield from methods[n].done())) + dones[n] = bool(methods[n].get_done(sim)) + await sim.tick() for n in methods: if not enables[n]: assert not dones[n] @@ -66,7 +68,7 @@ def process(): assert not any(dones.values()) with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class UnsatisfiableTriangleTestCircuit(Elaboratable): @@ -148,17 +150,19 @@ def test_transitivity(self): result: Optional[int] @def_method_mock(lambda: target) - def target_process(data): - nonlocal result - result = data + def target_process(data: int): + @MethodMock.effect + def eff(): + nonlocal result + result = data - def process(): + async def process(sim: TestbenchContext): nonlocal result for source, data, reqv1, reqv2 in product([circ.source1, circ.source2], [0, 1, 2, 3], [0, 1], [0, 1]): result = None - yield req1.eq(reqv1) - yield req2.eq(reqv2) - call_result = yield from source.call_try(data=data) + sim.set(req1, reqv1) + sim.set(req2, reqv2) + call_result = await source.call_try(sim, data=data) if not reqv1 and not reqv2: assert call_result is None @@ -169,4 +173,4 @@ def process(): assert result in possibles with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/transactron/test_transactron_lib_storage.py b/test/transactron/test_transactron_lib_storage.py index 404c14a2d..d5513fe7c 100644 --- a/test/transactron/test_transactron_lib_storage.py +++ b/test/transactron/test_transactron_lib_storage.py @@ -31,28 +31,24 @@ def generic_process( settle_count=0, name="", ): - def f(): + async def f(sim: TestbenchContext): while input_lst: # wait till all processes will end the previous cycle - yield from self.multi_settle(4) + await sim.delay(1e-9) elem = input_lst.pop() if isinstance(elem, OpNOP): - yield Tick() + await sim.tick() continue if input_verification is not None and not input_verification(elem): - yield Tick() + await sim.tick() continue - response = yield from method.call(**elem) - yield from self.multi_settle(settle_count) + response = await method.call(sim, **elem) + await sim.delay(settle_count * 1e-9) if behaviour_check is not None: - # Here accesses to circuit are allowed - ret = behaviour_check(elem, response) - if isinstance(ret, Generator): - yield from ret + behaviour_check(elem, response) if state_change is not None: - # It is standard python function by purpose to don't allow accessing circuit state_change(elem, response) - yield Tick() + await sim.tick() return f @@ -77,10 +73,10 @@ def check(elem, response): addr = elem["addr"] frozen_addr = frozenset(addr.items()) if frozen_addr in self.memory: - assert response["not_found"] == 0 - assert response["data"] == self.memory[frozen_addr] + assert response.not_found == 0 + assert data_const_to_dict(response.data) == self.memory[frozen_addr] else: - assert response["not_found"] == 1 + assert response.not_found == 1 return self.generic_process(self.circ.read, in_read, behaviour_check=check, settle_count=0, name="read") @@ -97,7 +93,7 @@ def verify_in(elem): return ret def check(elem, response): - assert response["not_found"] == int(frozenset(elem["addr"].items()) not in self.memory) + assert response.not_found == int(frozenset(elem["addr"].items()) not in self.memory) def modify_state(elem, response): if frozenset(elem["addr"].items()) in self.memory: @@ -129,7 +125,7 @@ def test_random(self, in_push, in_write, in_read, in_remove): with self.reinitialize_fixtures(): self.setUp() with self.run_simulation(self.circ, max_cycles=500) as sim: - sim.add_process(self.push_process(in_push)) - sim.add_process(self.read_process(in_read)) - sim.add_process(self.write_process(in_write)) - sim.add_process(self.remove_process(in_remove)) + sim.add_testbench(self.push_process(in_push)) + sim.add_testbench(self.read_process(in_read)) + sim.add_testbench(self.write_process(in_write)) + sim.add_testbench(self.remove_process(in_remove)) diff --git a/test/transactron/testing/test_infrastructure.py b/test/transactron/testing/test_infrastructure.py deleted file mode 100644 index cfd59ec87..000000000 --- a/test/transactron/testing/test_infrastructure.py +++ /dev/null @@ -1,31 +0,0 @@ -from amaranth import * -from transactron.testing import * - - -class EmptyCircuit(Elaboratable): - def __init__(self): - pass - - def elaborate(self, platform): - m = Module() - return m - - -class TestNow(TestCaseWithSimulator): - def setup_method(self): - self.test_cycles = 10 - self.m = SimpleTestCircuit(EmptyCircuit()) - - def process(self): - for k in range(self.test_cycles): - now = yield Now() - assert k == now - # check if second call don't change the returned value - now = yield Now() - assert k == now - - yield Tick() - - def test_random(self): - with self.run_simulation(self.m, 50) as sim: - sim.add_process(self.process) diff --git a/test/transactron/testing/test_log.py b/test/transactron/testing/test_log.py index 6e6711d8e..97efb9d3b 100644 --- a/test/transactron/testing/test_log.py +++ b/test/transactron/testing/test_log.py @@ -1,10 +1,9 @@ import pytest import re from amaranth import * -from amaranth.sim import Tick from transactron import * -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext from transactron.lib import logging LOGGER_NAME = "test_logger" @@ -70,13 +69,13 @@ class TestLog(TestCaseWithSimulator): def test_log(self, caplog): m = LogTest() - def proc(): + async def proc(sim: TestbenchContext): for i in range(50): - yield Tick() - yield m.input.eq(i) + await sim.tick() + sim.set(m.input, i) with self.run_simulation(m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) assert re.search( r"WARNING test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " @@ -93,13 +92,13 @@ def proc(): def test_error_log(self, caplog): m = ErrorLogTest() - def proc(): - yield Tick() - yield m.input.eq(1) + async def proc(sim: TestbenchContext): + await sim.tick() + sim.set(m.input, 1) with pytest.raises(AssertionError): with self.run_simulation(m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) assert re.search( r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " @@ -110,13 +109,13 @@ def proc(): def test_assertion(self, caplog): m = AssertionTest() - def proc(): - yield Tick() - yield m.input.eq(1) + async def proc(sim: TestbenchContext): + await sim.tick() + sim.set(m.input, 1) with pytest.raises(AssertionError): with self.run_simulation(m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) assert re.search( r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] Output differs", diff --git a/test/transactron/testing/test_validate_arguments.py b/test/transactron/testing/test_validate_arguments.py index 18066ff5d..7e7036975 100644 --- a/test/transactron/testing/test_validate_arguments.py +++ b/test/transactron/testing/test_validate_arguments.py @@ -2,11 +2,12 @@ from amaranth import * from amaranth.sim import * -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout, TestbenchContext from transactron import * -from transactron.testing.sugar import def_method_mock +from transactron.testing.method_mock import def_method_mock from transactron.lib import * +from transactron.testing.testbenchio import CallTrigger class ValidateArgumentsTestCircuit(Elaboratable): @@ -24,25 +25,26 @@ def elaborate(self, platform): class TestValidateArguments(TestCaseWithSimulator): def control_caller(self, caller: TestbenchIO, method: TestbenchIO): - def process(): + async def process(sim: TestbenchContext): + await sim.tick() for _ in range(100): val = random.randrange(2) pre_accepted_val = self.accepted_val - ret = yield from caller.call_try(data=val) - if ret is None: - assert val != pre_accepted_val or val == pre_accepted_val and (yield from method.done()) - else: + caller_data, method_data = await CallTrigger(sim).call(caller, data=val).sample(method) + if caller_data is not None: assert val == pre_accepted_val - assert ret["data"] == val + assert caller_data.data == val + else: + assert val != pre_accepted_val or val == pre_accepted_val and method_data is not None return process def validate_arguments(self, data: int): return data == self.accepted_val - def changer(self): + async def changer(self, sim: TestbenchContext): for _ in range(50): - yield Tick("sync_neg") + await sim.tick() self.accepted_val = 1 @def_method_mock(tb_getter=lambda self: self.m.method, validate_arguments=validate_arguments) @@ -54,6 +56,6 @@ def test_validate_arguments(self): self.m = ValidateArgumentsTestCircuit() self.accepted_val = 0 with self.run_simulation(self.m) as sim: - sim.add_process(self.changer) - sim.add_process(self.control_caller(self.m.caller1, self.m.method)) - sim.add_process(self.control_caller(self.m.caller2, self.m.method)) + sim.add_testbench(self.changer) + sim.add_testbench(self.control_caller(self.m.caller1, self.m.method)) + sim.add_testbench(self.control_caller(self.m.caller2, self.m.method)) diff --git a/test/transactron/utils/test_amaranth_ext.py b/test/transactron/utils/test_amaranth_ext.py index 7943ccb76..349fc0b87 100644 --- a/test/transactron/utils/test_amaranth_ext.py +++ b/test/transactron/utils/test_amaranth_ext.py @@ -1,5 +1,6 @@ from transactron.testing import * import random +import pytest from transactron.utils.amaranth_ext import MultiPriorityEncoder, RingMultiPriorityEncoder @@ -33,26 +34,25 @@ def get_expected_ring(input_width, output_count, input, first, last): ) class TestPriorityEncoder(TestCaseWithSimulator): def process(self, get_expected): - def f(): + async def f(sim: TestbenchContext): for _ in range(self.test_number): input = random.randrange(2**self.input_width) first = random.randrange(self.input_width) last = random.randrange(self.input_width) - yield self.circ.input.eq(input) + sim.set(self.circ.input, input) try: - yield self.circ.first.eq(first) - yield self.circ.last.eq(last) + sim.set(self.circ.first, first) + sim.set(self.circ.last, last) except AttributeError: pass - yield Settle() expected_output = get_expected(self.input_width, self.output_count, input, first, last) for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): if ex is None: - assert (yield valid) == 0 + assert sim.get(valid) == 0 else: - assert (yield valid) == 1 - assert (yield real) == ex - yield Delay(1e-7) + assert sim.get(valid) == 1 + assert sim.get(real) == ex + await sim.delay(1e-7) return f @@ -66,7 +66,7 @@ def test_random(self, test_class, verif_f, input_width, output_count): self.circ = test_class(self.input_width, self.output_count) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) + sim.add_testbench(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) def test_static_create_simple(self, test_class, verif_f, name): @@ -100,7 +100,7 @@ def elaborate(self, platform): self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) + sim.add_testbench(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) def test_static_create(self, test_class, verif_f, name): @@ -132,4 +132,4 @@ def elaborate(self, platform): self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) + sim.add_testbench(self.process(verif_f)) diff --git a/test/transactron/utils/test_onehotswitch.py b/test/transactron/utils/test_onehotswitch.py index 9d7dc843f..b0620c0a9 100644 --- a/test/transactron/utils/test_onehotswitch.py +++ b/test/transactron/utils/test_onehotswitch.py @@ -3,7 +3,7 @@ from transactron.utils import OneHotSwitch -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext from parameterized import parameterized @@ -30,33 +30,30 @@ def elaborate(self, platform): return m -class TestAssign(TestCaseWithSimulator): +class TestOneHotSwitch(TestCaseWithSimulator): @parameterized.expand([(False,), (True,)]) def test_onehotswitch(self, test_zero): circuit = OneHotSwitchCircuit(4, test_zero) - def switch_test_proc(): + async def switch_test_proc(sim: TestbenchContext): for i in range(len(circuit.input)): - yield circuit.input.eq(1 << i) - yield Settle() - assert (yield circuit.output) == i + sim.set(circuit.input, 1 << i) + assert sim.get(circuit.output) == i with self.run_simulation(circuit) as sim: - sim.add_process(switch_test_proc) + sim.add_testbench(switch_test_proc) def test_onehotswitch_zero(self): circuit = OneHotSwitchCircuit(4, True) - def switch_test_proc_zero(): + async def switch_test_proc_zero(sim: TestbenchContext): for i in range(len(circuit.input)): - yield circuit.input.eq(1 << i) - yield Settle() - assert (yield circuit.output) == i - assert not (yield circuit.zero) + sim.set(circuit.input, 1 << i) + assert sim.get(circuit.output) == i + assert not sim.get(circuit.zero) - yield circuit.input.eq(0) - yield Settle() - assert (yield circuit.zero) + sim.set(circuit.input, 0) + assert sim.get(circuit.zero) with self.run_simulation(circuit) as sim: - sim.add_process(switch_test_proc_zero) + sim.add_testbench(switch_test_proc_zero) diff --git a/test/transactron/utils/test_utils.py b/test/transactron/utils/test_utils.py index 63c176169..abd28f420 100644 --- a/test/transactron/utils/test_utils.py +++ b/test/transactron/utils/test_utils.py @@ -78,21 +78,21 @@ def setup_method(self): self.test_number = 40 self.m = PopcountTestCircuit(self.size) - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_popcount = yield self.m.sig_out + def check(self, sim: TestbenchContext, n): + sim.set(self.m.sig_in, n) + out_popcount = sim.get(self.m.sig_out) assert out_popcount == n.bit_count(), f"{n:x}" - def process(self): + async def process(self, sim: TestbenchContext): for i in range(self.test_number): n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) + self.check(sim, n) + sim.delay(1e-6) + self.check(sim, 2**self.size - 1) def test_popcount(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) class CLZTestCircuit(Elaboratable): @@ -124,21 +124,21 @@ def setup_method(self): self.test_number = 40 self.m = CLZTestCircuit(self.size) - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_clz = yield self.m.sig_out + def check(self, sim: TestbenchContext, n): + sim.set(self.m.sig_in, n) + out_clz = sim.get(self.m.sig_out) assert out_clz == (2**self.size) - n.bit_length(), f"{n:x}" - def process(self): + async def process(self, sim: TestbenchContext): for i in range(self.test_number): n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) + self.check(sim, n) + sim.delay(1e-6) + self.check(sim, 2**self.size - 1) def test_count_leading_zeros(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) class CTZTestCircuit(Elaboratable): @@ -170,10 +170,9 @@ def setup_method(self): self.test_number = 40 self.m = CTZTestCircuit(self.size) - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_ctz = yield self.m.sig_out + def check(self, sim: TestbenchContext, n): + sim.set(self.m.sig_in, n) + out_ctz = sim.get(self.m.sig_out) expected = 0 if n == 0: @@ -185,12 +184,13 @@ def check(self, n): assert out_ctz == expected, f"{n:x}" - def process(self): + async def process(self, sim: TestbenchContext): for i in range(self.test_number): n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) + self.check(sim, n) + await sim.delay(1e-6) + self.check(sim, 2**self.size - 1) def test_count_trailing_zeros(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index cd09816a2..91b94f9c6 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -166,6 +166,7 @@ def __init__( use_condition : bool Instead of `m.If` use simultaneus `condition` which allow to execute this filter if the condition is False and target is not ready. + When `use_condition` is true, `condition` must not be a `Method`. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. @@ -180,6 +181,8 @@ def __init__( self.condition = condition self.default = default + assert not (use_condition and isinstance(condition, Method)) + def elaborate(self, platform): m = TModule() diff --git a/transactron/testing/__init__.py b/transactron/testing/__init__.py index aa215228e..8c4940038 100644 --- a/transactron/testing/__init__.py +++ b/transactron/testing/__init__.py @@ -1,8 +1,10 @@ +from amaranth.sim._async import TestbenchContext, ProcessContext, SimulatorContext # noqa: F401 from .input_generation import * # noqa: F401 from .functions import * # noqa: F401 from .infrastructure import * # noqa: F401 -from .sugar import * # noqa: F401 +from .method_mock import * # noqa: F401 from .testbenchio import * # noqa: F401 from .profiler import * # noqa: F401 from .logging import * # noqa: F401 +from .tick_count import * # noqa: F401 from transactron.utils import data_layout # noqa: F401 diff --git a/transactron/testing/functions.py b/transactron/testing/functions.py index 7cb69b12e..ee1225154 100644 --- a/transactron/testing/functions.py +++ b/transactron/testing/functions.py @@ -1,31 +1,15 @@ -from amaranth import * -from amaranth.lib.data import Layout, StructLayout, View -from amaranth.sim._pycoro import Command -from typing import TypeVar, Any, Generator, TypeAlias, TYPE_CHECKING, Union -from transactron.utils._typing import RecordIntDict +import amaranth.lib.data as data +from typing import TypeAlias -if TYPE_CHECKING: - from amaranth.hdl._ast import Statement - from .infrastructure import CoreblocksCommand +MethodData: TypeAlias = "data.Const[data.StructLayout]" -T = TypeVar("T") -TestGen: TypeAlias = Generator[Union[Command, Value, "Statement", "CoreblocksCommand", None], Any, T] - - -def get_outputs(field: View) -> TestGen[RecordIntDict]: - # return dict of all signal values in a record because amaranth's simulator can't read all - # values of a View in a single yield - it can only read Values (Signals) - result = {} - layout = field.shape() - assert isinstance(layout, StructLayout) - for name, fld in layout: - val = field[name] - if isinstance(fld.shape, Layout): - result[name] = yield from get_outputs(View(fld.shape, val)) - elif isinstance(val, Value): - result[name] = yield val - else: - raise ValueError - return result +def data_const_to_dict(c: "data.Const[data.Layout]"): + ret = {} + for k, _ in c.shape(): + v = c[k] + if isinstance(v, data.Const): + v = data_const_to_dict(v) + ret[k] = v + return ret diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py index 861428fc3..a67e1cf67 100644 --- a/transactron/testing/infrastructure.py +++ b/transactron/testing/infrastructure.py @@ -4,19 +4,20 @@ import os import random import functools -import warnings from contextlib import contextmanager, nullcontext -from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias, Optional -from abc import ABC +from collections.abc import Callable +from typing import TypeVar, Generic, Type, TypeGuard, Any, cast, TypeAlias, Optional from amaranth import * from amaranth.sim import * +from amaranth.sim._async import SimulatorContext from transactron.utils.dependencies import DependencyContext, DependencyManager from .testbenchio import TestbenchIO from .profiler import profiler_process, Profile -from .functions import TestGen from .logging import make_logging_process, parse_logging_level, _LogFormatter +from .tick_count import make_tick_count_process from .gtkw_extension import write_vcd_ext +from .method_mock import MethodMock from transactron import Method from transactron.lib import AdapterTrans from transactron.core.keys import TransactionManagerKey @@ -24,6 +25,9 @@ from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals +__all__ = ["SimpleTestCircuit", "PysimSimulator", "TestCaseWithSimulator"] + + T = TypeVar("T") _T_nested_collection: TypeAlias = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] @@ -56,7 +60,10 @@ def __getattr__(self, name: str) -> Any: def elaborate(self, platform): def transform_methods_to_testbenchios( container: _T_nested_collection[Method], - ) -> tuple[_T_nested_collection["TestbenchIO"], Union[ModuleConnector, "TestbenchIO"]]: + ) -> tuple[ + _T_nested_collection["TestbenchIO"], + "ModuleConnector | TestbenchIO", + ]: if isinstance(container, list): tb_list = [] mc_list = [] @@ -119,43 +126,6 @@ def elaborate(self, platform) -> HasElaborate: return m -class CoreblocksCommand(ABC): - pass - - -class Now(CoreblocksCommand): - pass - - -class SyncProcessWrapper: - def __init__(self, f): - self.org_process = f - self.current_cycle = 0 - - def _wrapping_function(self): - response = None - org_coroutine = self.org_process() - try: - while True: - # call orginal test process and catch data yielded by it in `command` variable - command = org_coroutine.send(response) - # If process wait for new cycle - if command is None or isinstance(command, Tick): - command = command or Tick() - # TODO: use of other domains can mess up the counter! - if command.domain == "sync": - self.current_cycle += 1 - # forward to amaranth - yield command - elif isinstance(command, Now): - response = self.current_cycle - # Pass everything else to amaranth simulator without modifications - else: - response = yield command - except StopIteration: - pass - - class PysimSimulator(Simulator): def __init__( self, @@ -197,10 +167,6 @@ def __init__( self.deadline = clk_period * max_cycles - def add_process(self, f: Callable[[], TestGen]): - f_wrapped = SyncProcessWrapper(f) - super().add_process(f_wrapped._wrapping_function) - def run(self) -> bool: with self.ctx: self.run_until(self.deadline) @@ -212,34 +178,44 @@ class TestCaseWithSimulator: dependency_manager: DependencyManager @contextmanager - def configure_dependency_context(self): + def _configure_dependency_context(self): self.dependency_manager = DependencyManager() with DependencyContext(self.dependency_manager): yield Tick() - def add_class_mocks(self, sim: PysimSimulator) -> None: + def add_mock(self, sim: PysimSimulator, val: MethodMock): + sim.add_process(val.output_process) + if val.validate_arguments is not None: + sim.add_process(val.validate_arguments_process) + sim.add_testbench(val.effect_process, background=True) + + def _add_class_mocks(self, sim: PysimSimulator) -> None: for key in dir(self): val = getattr(self, key) if hasattr(val, "_transactron_testing_process"): sim.add_process(val) + elif hasattr(val, "_transactron_method_mock"): + self.add_mock(sim, val()) - def add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: + def _add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: for key, val in frame_locals.items(): if hasattr(val, "_transactron_testing_process"): sim.add_process(val) + elif hasattr(val, "_transactron_method_mock"): + self.add_mock(sim, val()) - def add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: - self.add_class_mocks(sim) - self.add_local_mocks(sim, frame_locals) + def _add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: + self._add_class_mocks(sim) + self._add_local_mocks(sim, frame_locals) - def configure_traces(self): + def _configure_traces(self): traces_file = None if "__TRANSACTRON_DUMP_TRACES" in os.environ: traces_file = self._transactron_current_output_file_name self._transactron_infrastructure_traces_file = traces_file @contextmanager - def configure_profiles(self): + def _configure_profiles(self): profile = None if "__TRANSACTRON_PROFILE" in os.environ: @@ -264,7 +240,7 @@ def f(): profile.encode(f"{profile_dir}/{profile_file}.json") @contextmanager - def configure_logging(self): + def _configure_logging(self): def on_error(): assert False, "Simulation finished due to an error" @@ -292,10 +268,11 @@ def reinitialize_fixtures(self): self._transactron_base_output_file_name + "_" + str(self._transactron_hypothesis_iter_counter) ) self._transactron_sim_processes_to_add: list[Callable[[], Optional[Callable]]] = [] - with self.configure_dependency_context(): - self.configure_traces() - with self.configure_profiles(): - with self.configure_logging(): + with self._configure_dependency_context(): + self._configure_traces() + with self._configure_profiles(): + with self._configure_logging(): + self._transactron_sim_processes_to_add.append(make_tick_count_process) yield self._transactron_hypothesis_iter_counter += 1 @@ -323,7 +300,7 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra traces_file=self._transactron_infrastructure_traces_file, clk_period=clk_period, ) - self.add_all_mocks(sim, sys._getframe(2).f_locals) + self._add_all_mocks(sim, sys._getframe(2).f_locals) yield sim @@ -332,34 +309,25 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra if ret is not None: sim.add_process(ret) - with warnings.catch_warnings(): - # TODO: figure out testing without settles! - warnings.filterwarnings("ignore", r"The `Settle` command is deprecated per RFC 27\.") - - res = sim.run() + res = sim.run() assert res, "Simulation time limit exceeded" - def tick(self, cycle_cnt: int = 1): + async def tick(self, sim: SimulatorContext, cycle_cnt: int = 1): """ - Yields for the given number of cycles. + Waits for the given number of cycles. """ - for _ in range(cycle_cnt): - yield Tick() + await sim.tick() - def random_wait(self, max_cycle_cnt: int, *, min_cycle_cnt: int = 0): + async def random_wait(self, sim: SimulatorContext, max_cycle_cnt: int, *, min_cycle_cnt: int = 0): """ Wait for a random amount of cycles in range [min_cycle_cnt, max_cycle_cnt] """ - yield from self.tick(random.randrange(min_cycle_cnt, max_cycle_cnt + 1)) + await self.tick(sim, random.randrange(min_cycle_cnt, max_cycle_cnt + 1)) - def random_wait_geom(self, prob: float = 0.5): + async def random_wait_geom(self, sim: SimulatorContext, prob: float = 0.5): """ Wait till the first success, where there is `prob` probability for success in each cycle. """ while random.random() > prob: - yield Tick() - - def multi_settle(self, settle_count: int = 1): - for _ in range(settle_count): - yield Settle() + await sim.tick() diff --git a/transactron/testing/logging.py b/transactron/testing/logging.py index 7c8edf1dc..2a40d6f1b 100644 --- a/transactron/testing/logging.py +++ b/transactron/testing/logging.py @@ -1,9 +1,12 @@ -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any import logging +import itertools -from amaranth.sim import Passive, Tick +from amaranth.sim._async import ProcessContext from transactron.lib import logging as tlog +from transactron.utils.dependencies import DependencyContext +from .tick_count import TicksKey __all__ = ["make_logging_process", "parse_logging_level"] @@ -31,7 +34,7 @@ def parse_logging_level(str: str) -> tlog.LogLevel: raise ValueError("Log level must be either {error, warn, info, debug} or a non-negative integer.") -_sim_cycle = 0 +_sim_cycle: int = 0 class _LogFormatter(logging.Formatter): @@ -65,17 +68,15 @@ def make_logging_process(level: tlog.LogLevel, namespace_regexp: str, on_error: root_logger = logging.getLogger() - def handle_logs(): - if not (yield combined_trigger): - return + def handle_logs(record_vals: Iterable[int]) -> None: + it = iter(record_vals) for record in records: - if not (yield record.trigger): - continue + trigger = next(it) + values = [next(it) for _ in record.fields] - values: list[int] = [] - for field in record.fields: - values.append((yield field)) + if not trigger: + continue formatted_msg = record.format(*values) @@ -91,15 +92,18 @@ def handle_logs(): if record.level >= logging.ERROR: on_error() - def log_process(): + async def log_process(sim: ProcessContext) -> None: global _sim_cycle - _sim_cycle = 0 - - yield Passive() - while True: - yield Tick("sync_neg") - yield from handle_logs() - yield Tick() - _sim_cycle += 1 + ticks = DependencyContext.get().get_dependency(TicksKey()) + + async for _, _, ticks_val, combined_trigger_val, *record_vals in ( + sim.tick("sync_neg") + .sample(ticks, combined_trigger) + .sample(*itertools.chain(*([record.trigger] + record.fields for record in records))) + ): + if not combined_trigger_val: + continue + _sim_cycle = ticks_val + handle_logs(record_vals) return log_process diff --git a/transactron/testing/method_mock.py b/transactron/testing/method_mock.py new file mode 100644 index 000000000..9587ae19f --- /dev/null +++ b/transactron/testing/method_mock.py @@ -0,0 +1,175 @@ +from contextlib import contextmanager +import functools +from typing import Callable, Any, Optional + +from amaranth.sim._async import SimulatorContext +from transactron.lib.adapters import Adapter +from transactron.utils.transactron_helpers import async_mock_def_helper +from .testbenchio import TestbenchIO +from transactron.utils._typing import RecordIntDict + + +__all__ = ["MethodMock", "def_method_mock"] + + +class MethodMock: + def __init__( + self, + adapter: Adapter, + function: Callable[..., Optional[RecordIntDict]], + *, + validate_arguments: Optional[Callable[..., bool]] = None, + enable: Callable[[], bool] = lambda: True, + delay: float = 0, + ): + self.adapter = adapter + self.function = function + self.validate_arguments = validate_arguments + self.enable = enable + self.delay = delay + self._effects: list[Callable[[], None]] = [] + self._freeze = False + + _current_mock: Optional["MethodMock"] = None + + @staticmethod + def effect(effect: Callable[[], None]): + assert MethodMock._current_mock is not None + MethodMock._current_mock._effects.append(effect) + + @contextmanager + def _context(self): + assert MethodMock._current_mock is None + MethodMock._current_mock = self + try: + yield + finally: + MethodMock._current_mock = None + + async def output_process( + self, + sim: SimulatorContext, + ) -> None: + sync = sim._design.lookup_domain("sync", None) # type: ignore + async for *_, done, arg, clk in sim.changed(self.adapter.done, self.adapter.data_out).edge(sync.clk, 1): + if clk: + self._freeze = True + if not done or self._freeze: + continue + self._effects = [] + with self._context(): + ret = async_mock_def_helper(self, self.function, arg) + sim.set(self.adapter.data_in, ret) + + async def validate_arguments_process(self, sim: SimulatorContext) -> None: + assert self.validate_arguments is not None + sync = sim._design.lookup_domain("sync", None) # type: ignore + async for *args, clk, _ in ( + sim.changed(*(a for a, _ in self.adapter.validators)).edge(sync.clk, 1).edge(self.adapter.en, 1) + ): + assert len(args) == len(self.adapter.validators) # TODO: remove later + if clk: + self._freeze = True + if self._freeze: + continue + for arg, r in zip(args, (r for _, r in self.adapter.validators)): + sim.set(r, async_mock_def_helper(self, self.validate_arguments, arg)) + + async def effect_process(self, sim: SimulatorContext) -> None: + sim.set(self.adapter.en, self.enable()) + async for *_, done in sim.tick().sample(self.adapter.done): + # Disabling the method on each cycle forces an edge when it is reenabled again. + # The method body won't be executed until the effects are done. + sim.set(self.adapter.en, False) + + # First, perform pending effects, updating internal state. + with sim.critical(): + if done: + for eff in self._effects: + eff() + + # Ensure that the effects of all mocks are applied. Delay 0 also does this! + await sim.delay(self.delay) + + # Next, enable the method. The output will be updated by a combinational process. + self._effects = [] + self._freeze = False + sim.set(self.adapter.en, self.enable()) + + +def def_method_mock( + tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], **kwargs +) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], MethodMock]]: + """ + Decorator function to create method mock handlers. It should be applied on + a function which describes functionality which we want to invoke on method call. + This function will be called on every clock cycle when the method is active, + and also on combinational changes to inputs. + + The decorated function can have a single argument `arg`, which receives + the arguments passed to a method as a `data.Const`, or multiple named arguments, + which correspond to named arguments of the method. + + This decorator can be applied to function definitions or method definitions. + When applied to a method definition, lambdas passed to `def_method_mock` + need to take a `self` argument, which should be the first. + + Mocks defined at class level or at test level are automatically discovered and + don't need to be manually added to the simulation. + + Any side effects (state modification, assertions, etc.) need to be guarded + using the `MethodMock.effect` decorator. + + Make sure to defer accessing state, since decorators are evaluated eagerly + during function declaration. + + Parameters + ---------- + tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] + Function to get the TestbenchIO of the mocked method. + enable : Callable[[], bool] | Callable[[Any], bool] + Function which decides if the method is enabled in a given clock cycle. + validate_arguments : Callable[..., bool] + Function which validates call arguments. This applies only to Adapters + with `with_validate_arguments` set to True. + delay : float + Simulation time delay for method mock calling. Used for synchronization + between different mocks and testbench processes. + + Example + ------- + ``` + @def_method_mock(lambda: m.target[k]) + def process(arg): + return {"data": arg["data"] + k} + ``` + or for class methods + ``` + @def_method_mock(lambda self: self.target[k]) + def process(self, data): + return {"data": data + k} + ``` + """ + + def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], MethodMock]: + @functools.wraps(func) + def mock(func_self=None, /) -> MethodMock: + f = func + getter: Any = tb_getter + kw = kwargs + if func_self is not None: + getter = getter.__get__(func_self) + f = f.__get__(func_self) + kw = {} + for k, v in kwargs.items(): + bind = getattr(v, "__get__", None) + kw[k] = bind(func_self) if bind else v + tb = getter() + assert isinstance(tb, TestbenchIO) + assert isinstance(tb.adapter, Adapter) + return MethodMock(tb.adapter, f, **kw) + + mock._transactron_method_mock = 1 # type: ignore + return mock + + return decorator diff --git a/transactron/testing/profiler.py b/transactron/testing/profiler.py index 795c7f293..ace2b6327 100644 --- a/transactron/testing/profiler.py +++ b/transactron/testing/profiler.py @@ -1,37 +1,46 @@ -from amaranth.sim import * +from amaranth import Cat +from amaranth.lib.data import StructLayout, View +from amaranth.sim._async import ProcessContext from transactron.core import TransactionManager from transactron.core.manager import MethodMap from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileData, ProfileSamples, TransactionSamples -from .functions import TestGen __all__ = ["profiler_process"] def profiler_process(transaction_manager: TransactionManager, profile: Profile): - def process() -> TestGen: + async def process(sim: ProcessContext) -> None: profile_data, get_id = ProfileData.make(transaction_manager) method_map = MethodMap(transaction_manager.transactions) profile.transactions_and_methods = profile_data.transactions_and_methods - yield Passive() - while True: - yield Tick("sync_neg") + transaction_sample_layout = StructLayout({"request": 1, "runnable": 1, "grant": 1}) + async for _, _, *data in ( + sim.tick() + .sample( + *( + View(transaction_sample_layout, Cat(transaction.request, transaction.runnable, transaction.grant)) + for transaction in method_map.transactions + ) + ) + .sample(*(method.run for method in method_map.methods)) + ): + transaction_data = data[: len(method_map.transactions)] + method_data = data[len(method_map.transactions) :] samples = ProfileSamples() - for transaction in method_map.transactions: + for transaction, tsample in zip(method_map.transactions, transaction_data): samples.transactions[get_id(transaction)] = TransactionSamples( - bool((yield transaction.request)), - bool((yield transaction.runnable)), - bool((yield transaction.grant)), + bool(tsample.request), + bool(tsample.runnable), + bool(tsample.grant), ) - for method in method_map.methods: - samples.methods[get_id(method)] = MethodSamples(bool((yield method.run))) + for method, run in zip(method_map.methods, method_data): + samples.methods[get_id(method)] = MethodSamples(bool(run)) cprof = CycleProfile.make(samples, profile_data) profile.cycles.append(cprof) - yield Tick() - return process diff --git a/transactron/testing/sugar.py b/transactron/testing/sugar.py deleted file mode 100644 index de1dc5e21..000000000 --- a/transactron/testing/sugar.py +++ /dev/null @@ -1,82 +0,0 @@ -import functools -from typing import Callable, Any, Optional -from .testbenchio import TestbenchIO, TestGen -from transactron.utils._typing import RecordIntDict - - -def def_method_mock( - tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs -) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]: - """ - Decorator function to create method mock handlers. It should be applied on - a function which describes functionality which we want to invoke on method call. - Such function will be wrapped by `method_handle_loop` and called on each - method invocation. - - Function `f` should take only one argument `arg` - data used in function - invocation - and should return data to be sent as response to the method call. - - Function `f` can also be a method and take two arguments `self` and `arg`, - the data to be passed on to invoke a method. It should return data to be sent - as response to the method call. - - Instead of the `arg` argument, the data can be split into keyword arguments. - - Make sure to defer accessing state, since decorators are evaluated eagerly - during function declaration. - - Parameters - ---------- - tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] - Function to get the TestbenchIO providing appropriate `method_handle_loop`. - **kwargs - Arguments passed to `method_handle_loop`. - - Example - ------- - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k]) - def process(arg): - return {"data": arg["data"] + k} - return process - ``` - or equivalently - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k], settle=1, enable=False) - def process(data): - return {"data": data + k} - return process - ``` - or for class methods - ``` - @def_method_mock(lambda self: self.target[k], settle=1, enable=False) - def process(self, data): - return {"data": data + k} - ``` - """ - - def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], TestGen[None]]: - @functools.wraps(func) - def mock(func_self=None, /) -> TestGen[None]: - f = func - getter: Any = tb_getter - kw = kwargs - if func_self is not None: - getter = getter.__get__(func_self) - f = f.__get__(func_self) - kw = {} - for k, v in kwargs.items(): - bind = getattr(v, "__get__", None) - kw[k] = bind(func_self) if bind else v - tb = getter() - assert isinstance(tb, TestbenchIO) - yield from tb.method_handle_loop(f, extra_settle_count=sched_prio, **kw) - - mock._transactron_testing_process = 1 # type: ignore - return mock - - return decorator diff --git a/transactron/testing/testbenchio.py b/transactron/testing/testbenchio.py index 9af6340bf..05531842c 100644 --- a/transactron/testing/testbenchio.py +++ b/transactron/testing/testbenchio.py @@ -1,11 +1,130 @@ +from collections.abc import Generator, Iterable from amaranth import * -from amaranth.sim import Settle, Passive, Tick -from typing import Optional, Callable +from amaranth.lib.data import View, StructLayout +from amaranth.sim._async import SimulatorContext, TestbenchContext +from typing import Any, Optional + from transactron.lib import AdapterBase -from transactron.lib.adapters import Adapter -from transactron.utils import ValueLike, SignalBundle, mock_def_helper, assign -from transactron.utils._typing import RecordIntDictRet, RecordValueDict, RecordIntDict -from .functions import get_outputs, TestGen +from transactron.utils import ValueLike +from .functions import MethodData + + +__all__ = ["CallTrigger", "TestbenchIO"] + + +class CallTrigger: + """A trigger which allows to call multiple methods and sample signals. + + The `call()` and `call_try()` methods on a `TestbenchIO` always wait at least one clock cycle. It follows + that these methods can't be used to perform calls to multiple methods in a single clock cycle. Usually + this is not a problem, as different methods can be called from different simulation processes. But in cases + when more control over the time when different calls happen is needed, this trigger class allows to call + many methods in a single clock cycle. + """ + + def __init__( + self, + sim: SimulatorContext, + _calls: Iterable[ValueLike | tuple["TestbenchIO", Optional[dict[str, Any]]]] = (), + ): + """ + Parameters + ---------- + sim: SimulatorContext + Amaranth simulator context. + """ + self.sim = sim + self.calls_and_values: list[ValueLike | tuple[TestbenchIO, Optional[dict[str, Any]]]] = list(_calls) + + def sample(self, *values: "ValueLike | TestbenchIO"): + """Sample a signal or a method result on a clock edge. + + Values are sampled like in standard Amaranth `TickTrigger`. Sampling a method result works like `call()`, + but the method is not called - another process can do that instead. If the method was not called, the + sampled value is `None`. + + Parameters + ---------- + *values: ValueLike | TestbenchIO + Value or method to sample. + """ + new_calls_and_values: list[ValueLike | tuple["TestbenchIO", None]] = [] + for value in values: + if isinstance(value, TestbenchIO): + new_calls_and_values.append((value, None)) + else: + new_calls_and_values.append(value) + return CallTrigger(self.sim, (*self.calls_and_values, *new_calls_and_values)) + + def call(self, tbio: "TestbenchIO", data: dict[str, Any] = {}, /, **kwdata): + """Call a method and sample its result. + + Adds a method call to the trigger. The method result is sampled on a clock edge. If the call did not + succeed, the sampled value is `None`. + + Parameters + ---------- + tbio: TestbenchIO + The method to call. + data: dict[str, Any] + Method call arguments stored in a dict. + **kwdata: Any + Method call arguments passed as keyword arguments. If keyword arguments are used, + the `data` argument should not be provided. + """ + if data and kwdata: + raise TypeError("call() takes either a single dict or keyword arguments") + return CallTrigger(self.sim, (*self.calls_and_values, (tbio, data or kwdata))) + + async def until_done(self) -> Any: + """Wait until at least one of the calls succeeds. + + The `CallTrigger` normally acts like `TickTrigger`, e.g. awaiting on it advances the clock to the next + clock edge. It is possible that none of the calls could not be performed, for example because the called + methods were not enabled. In case we only want to focus on the cycles when one of the calls succeeded, + `until_done` can be used. This works like `until()` in `TickTrigger`. + """ + async for results in self: + if any(res is not None for res in results): + return results + + def __await__(self) -> Generator: + only_calls = [t for t in self.calls_and_values if isinstance(t, tuple)] + only_values = [t for t in self.calls_and_values if not isinstance(t, tuple)] + + for tbio, data in only_calls: + if data is not None: + tbio.call_init(self.sim, data) + + def layout_for(tbio: TestbenchIO): + return StructLayout({"outputs": tbio.adapter.data_out.shape(), "done": 1}) + + trigger = ( + self.sim.tick() + .sample(*(View(layout_for(tbio), Cat(tbio.outputs, tbio.done)) for tbio, _ in only_calls)) + .sample(*only_values) + ) + _, _, *results = yield from trigger.__await__() + + for tbio, data in only_calls: + if data is not None: + tbio.disable(self.sim) + + values_it = iter(results[len(only_calls) :]) + calls_it = (s.outputs if s.done else None for s in results[: len(only_calls)]) + + def ret(): + for v in self.calls_and_values: + if isinstance(v, tuple): + yield next(calls_it) + else: + yield next(values_it) + + return tuple(ret()) + + async def __aiter__(self): + while True: + yield await self class TestbenchIO(Elaboratable): @@ -19,134 +138,69 @@ def elaborate(self, platform): # Low-level operations - def set_enable(self, en) -> TestGen[None]: - yield self.adapter.en.eq(1 if en else 0) + def set_enable(self, sim: SimulatorContext, en): + sim.set(self.adapter.en, 1 if en else 0) - def enable(self) -> TestGen[None]: - yield from self.set_enable(True) + def enable(self, sim: SimulatorContext): + self.set_enable(sim, True) - def disable(self) -> TestGen[None]: - yield from self.set_enable(False) + def disable(self, sim: SimulatorContext): + self.set_enable(sim, False) - def done(self) -> TestGen[int]: - return (yield self.adapter.done) + @property + def done(self): + return self.adapter.done - def wait_until_done(self) -> TestGen[None]: - while (yield self.adapter.done) != 1: - yield Tick() + @property + def outputs(self): + return self.adapter.data_out - def set_inputs(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from assign(self.adapter.data_in, data) + def set_inputs(self, sim: SimulatorContext, data): + sim.set(self.adapter.data_in, data) - def get_outputs(self) -> TestGen[RecordIntDictRet]: - return (yield from get_outputs(self.adapter.data_out)) + def get_done(self, sim: TestbenchContext): + return sim.get(self.adapter.done) - # Operations for AdapterTrans + def get_outputs(self, sim: TestbenchContext) -> MethodData: + return sim.get(self.adapter.data_out) - def call_init(self, data: RecordValueDict = {}, /, **kwdata: ValueLike | RecordValueDict) -> TestGen[None]: - if data and kwdata: - raise TypeError("call_init() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.enable() - yield from self.set_inputs(data) + def sample_outputs(self, sim: SimulatorContext): + return sim.tick().sample(self.adapter.data_out) - def call_result(self) -> TestGen[Optional[RecordIntDictRet]]: - if (yield from self.done()): - return (yield from self.get_outputs()) - return None + def sample_outputs_until_done(self, sim: SimulatorContext): + return self.sample_outputs(sim).until(self.adapter.done) - def call_do(self) -> TestGen[RecordIntDict]: - while (outputs := (yield from self.call_result())) is None: - yield Tick() - yield from self.disable() - return outputs + def sample_outputs_done(self, sim: SimulatorContext): + return sim.tick().sample(self.adapter.data_out, self.adapter.done) - def call_try( - self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict - ) -> TestGen[Optional[RecordIntDictRet]]: - if data and kwdata: - raise TypeError("call_try() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.call_init(data) - yield Tick() - outputs = yield from self.call_result() - yield from self.disable() - return outputs + # Operations for AdapterTrans - def call(self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict) -> TestGen[RecordIntDictRet]: + def call_init(self, sim: SimulatorContext, data={}, /, **kwdata): if data and kwdata: - raise TypeError("call() takes either a single dict or keyword arguments") + raise TypeError("call_init() takes either a single dict or keyword arguments") if not data: data = kwdata - yield from self.call_init(data) - yield Tick() - return (yield from self.call_do()) - - # Operations for Adapter + self.enable(sim) + self.set_inputs(sim, data) - def method_argument(self) -> TestGen[Optional[RecordIntDictRet]]: - return (yield from self.call_result()) + def get_call_result(self, sim: TestbenchContext) -> Optional[MethodData]: + if self.get_done(sim): + return self.get_outputs(sim) + return None - def method_return(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from self.set_inputs(data) + async def call_result(self, sim: SimulatorContext) -> Optional[MethodData]: + *_, data, done = await self.sample_outputs_done(sim) + if done: + return data + return None - def method_handle( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - validate_arguments: Optional[Callable[..., bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - enable = enable or (lambda: True) - yield from self.set_enable(enable()) - - def handle_validate_arguments(): - if validate_arguments is not None: - assert isinstance(self.adapter, Adapter) - for a, r in self.adapter.validators: - ret_out = mock_def_helper(self, validate_arguments, (yield from get_outputs(a))) - yield r.eq(ret_out) - for _ in range(extra_settle_count + 1): - yield Settle() - - # One extra Settle() required to propagate enable signal. - for _ in range(extra_settle_count + 1): - yield Settle() - yield from handle_validate_arguments() - while (arg := (yield from self.method_argument())) is None: - yield Tick() - - yield from self.set_enable(enable()) - for _ in range(extra_settle_count + 1): - yield Settle() - yield from handle_validate_arguments() - - ret_out = mock_def_helper(self, function, arg) - yield from self.method_return(ret_out or {}) - yield Tick() - yield from self.set_enable(False) - - def method_handle_loop( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - validate_arguments: Optional[Callable[..., bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - yield Passive() - while True: - yield from self.method_handle( - function, - enable=enable, - validate_arguments=validate_arguments, - extra_settle_count=extra_settle_count, - ) + async def call_do(self, sim: SimulatorContext) -> MethodData: + *_, outputs = await self.sample_outputs_until_done(sim) + self.disable(sim) + return outputs - # Debug signals + async def call_try(self, sim: SimulatorContext, data={}, /, **kwdata) -> Optional[MethodData]: + return (await CallTrigger(sim).call(self, data, **kwdata))[0] - def debug_signals(self) -> SignalBundle: - return self.adapter.debug_signals() + async def call(self, sim: SimulatorContext, data={}, /, **kwdata) -> MethodData: + return (await CallTrigger(sim).call(self, data, **kwdata).until_done())[0] diff --git a/transactron/testing/tick_count.py b/transactron/testing/tick_count.py new file mode 100644 index 000000000..a2d3828d6 --- /dev/null +++ b/transactron/testing/tick_count.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass + +from amaranth import Signal +from amaranth.sim._async import ProcessContext + +from transactron.utils.dependencies import DependencyContext, SimpleKey + + +__all__ = ["TicksKey", "make_tick_count_process"] + + +@dataclass(frozen=True) +class TicksKey(SimpleKey[Signal]): + pass + + +def make_tick_count_process(): + ticks = Signal(64) + DependencyContext.get().add_dependency(TicksKey(), ticks) + + async def process(sim: ProcessContext): + async for _, _, ticks_val in sim.tick().sample(ticks): + sim.set(ticks, ticks_val + 1) + + return process diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py index 1fae8827a..9cb23cd17 100644 --- a/transactron/utils/transactron_helpers.py +++ b/transactron/utils/transactron_helpers.py @@ -8,6 +8,7 @@ from amaranth import * from amaranth import tracer from amaranth.lib.data import StructLayout +import amaranth.lib.data as data __all__ = [ @@ -17,6 +18,7 @@ "def_helper", "method_def_helper", "mock_def_helper", + "async_mock_def_helper", "get_src_loc", "from_method_layout", "make_layout", @@ -103,6 +105,13 @@ def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) +def async_mock_def_helper(tb, func: Callable[..., T], arg: "data.Const[StructLayout]") -> T: + marg = {} + for k, _ in arg.shape(): + marg[k] = arg[k] + return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], marg, **marg) + + def method_def_helper(method, func: Callable[..., T], arg: MethodStruct) -> T: kwargs = {k: arg[k] for k in arg.shape().members} return def_helper(f"method definition for {method}", func, MethodStruct, arg, **kwargs)