diff --git a/test/core/test_transactions.py b/test/core/test_transactions.py index 46ef5f6..fd4f9e7 100644 --- a/test/core/test_transactions.py +++ b/test/core/test_transactions.py @@ -1,5 +1,6 @@ from abc import abstractmethod from unittest.case import TestCase +from amaranth_types import HasElaborate import pytest from amaranth import * from amaranth.sim import * @@ -56,52 +57,52 @@ def count_test(self, sched, cnt): assert len(sched.grant) == cnt assert len(sched.valid) == 1 - def sim_step(self, sched, request, expected_grant): - yield sched.requests.eq(request) - yield Tick() + async def sim_step(self, sim, sched: Scheduler, request: int, expected_grant: int): + sim.set(sched.requests, request) + _, _, valid, grant = await sim.tick().sample(sched.valid, sched.grant) if request == 0: - assert not (yield sched.valid) + assert not valid else: - assert (yield sched.grant) == expected_grant - assert (yield sched.valid) + assert grant == expected_grant + assert valid def test_single(self): sched = Scheduler(1) self.count_test(sched, 1) - def process(): - yield from self.sim_step(sched, 0, 0) - yield from self.sim_step(sched, 1, 1) - yield from self.sim_step(sched, 1, 1) - yield from self.sim_step(sched, 0, 0) + async def process(sim): + await self.sim_step(sim, sched, 0, 0) + await self.sim_step(sim, sched, 1, 1) + await self.sim_step(sim, sched, 1, 1) + await self.sim_step(sim, sched, 0, 0) with self.run_simulation(sched) as sim: - sim.add_process(process) + sim.add_testbench(process) def test_multi(self): sched = Scheduler(4) self.count_test(sched, 4) - def process(): - yield from self.sim_step(sched, 0b0000, 0b0000) - yield from self.sim_step(sched, 0b1010, 0b0010) - yield from self.sim_step(sched, 0b1010, 0b1000) - yield from self.sim_step(sched, 0b1010, 0b0010) - yield from self.sim_step(sched, 0b1001, 0b1000) - yield from self.sim_step(sched, 0b1001, 0b0001) + async def process(sim): + await self.sim_step(sim, sched, 0b0000, 0b0000) + await self.sim_step(sim, sched, 0b1010, 0b0010) + await self.sim_step(sim, sched, 0b1010, 0b1000) + await self.sim_step(sim, sched, 0b1010, 0b0010) + await self.sim_step(sim, sched, 0b1001, 0b1000) + await self.sim_step(sim, sched, 0b1001, 0b0001) - yield from self.sim_step(sched, 0b1111, 0b0010) - yield from self.sim_step(sched, 0b1111, 0b0100) - yield from self.sim_step(sched, 0b1111, 0b1000) - yield from self.sim_step(sched, 0b1111, 0b0001) + await self.sim_step(sim, sched, 0b1111, 0b0010) + await self.sim_step(sim, sched, 0b1111, 0b0100) + await self.sim_step(sim, sched, 0b1111, 0b1000) + await self.sim_step(sim, sched, 0b1111, 0b0001) - yield from self.sim_step(sched, 0b0000, 0b0000) - yield from self.sim_step(sched, 0b0010, 0b0010) - yield from self.sim_step(sched, 0b0010, 0b0010) + await self.sim_step(sim, sched, 0b0000, 0b0000) + await self.sim_step(sim, sched, 0b0010, 0b0010) + await self.sim_step(sim, sched, 0b0010, 0b0010) with self.run_simulation(sched) as sim: - sim.add_process(process) + sim.add_testbench(process) class TransactionConflictTestCircuit(Elaboratable): @@ -132,14 +133,19 @@ def setup_method(self): random.seed(42) def make_process( - self, io: TestbenchIO, prob: float, src: Iterable[int], tgt: Callable[[int], None], chk: Callable[[int], None] + self, + io: TestbenchIO, + prob: float, + src: Iterable[int], + tgt: Callable[[int], None], + chk: Callable[[int], None], ): - def process(): + async def process(sim): for i in src: while random.random() >= prob: - yield Tick() + await sim.tick() tgt(i) - r = yield from io.call(data=i) + r = await io.call(sim, data=i) chk(r["data"]) return process @@ -193,9 +199,9 @@ def test_calls(self, name, prob1, prob2, probout): self.m = TransactionConflictTestCircuit(self.__class__.scheduler) with self.run_simulation(self.m, add_transaction_module=False) as sim: - sim.add_process(self.make_in1_process(prob1)) - sim.add_process(self.make_in2_process(prob2)) - sim.add_process(self.make_out_process(probout)) + sim.add_testbench(self.make_in1_process(prob1)) + sim.add_testbench(self.make_in2_process(prob2)) + sim.add_testbench(self.make_out_process(probout)) assert not self.in_expected assert not self.out1_expected @@ -210,7 +216,7 @@ def __init__(self): self.t2 = Signal() @abstractmethod - def elaborate(self, platform) -> TModule: + def elaborate(self, platform) -> HasElaborate: raise NotImplementedError @@ -283,22 +289,22 @@ def setup_method(self): def test_priorities(self, priority: Priority): m = self.circuit(priority) - def process(): + async def process(sim): to_do = 5 * [(0, 1), (1, 0), (1, 1)] random.shuffle(to_do) for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Settle() - assert (yield m.t1) != (yield m.t2) + sim.set(m.r1, r1) + sim.set(m.r2, r2) + _, t1, t2 = await sim.delay(1e-9).sample(m.t1, m.t2) + assert t1 != t2 if r1 == 1 and r2 == 1: if priority == Priority.LEFT: - assert (yield m.t1) + assert t1 if priority == Priority.RIGHT: - assert (yield m.t2) + assert t2 with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) def test_unsatisfiable(self, priority: Priority): @@ -368,18 +374,18 @@ def setup_method(self): def test_scheduling(self): m = self.circuit() - def process(): + async def process(sim): to_do = 5 * [(0, 1), (1, 0), (1, 1)] random.shuffle(to_do) for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Tick() - assert (yield m.t1) == r1 - assert (yield m.t2) == r1 * r2 + sim.set(m.r1, r1) + sim.set(m.r2, r2) + *_, t1, t2 = await sim.tick().sample(m.t1, m.t2) + assert t1 == r1 + assert t2 == r1 * r2 with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) class ScheduleBeforeTestCircuit(SchedulingTestCircuit): @@ -414,18 +420,18 @@ def setup_method(self): def test_schedule_before(self): m = ScheduleBeforeTestCircuit() - def process(): + async def process(sim): to_do = 5 * [(0, 1), (1, 0), (1, 1)] random.shuffle(to_do) for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Tick() - assert (yield m.t1) == r1 - assert not (yield m.t2) + sim.set(m.r1, r1) + sim.set(m.r2, r2) + *_, t1, t2 = await sim.tick().sample(m.t1, m.t2) + assert t1 == r1 + assert not t2 with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) class SingleCallerTestCircuit(Elaboratable): diff --git a/test/lib/test_fifo.py b/test/lib/test_fifo.py index 39de892..b9d0c57 100644 --- a/test/lib/test_fifo.py +++ b/test/lib/test_fifo.py @@ -1,9 +1,8 @@ from amaranth import * -from amaranth.sim import Settle, Tick from transactron.lib import AdapterTrans, BasicFifo -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout, TestbenchContext from collections import deque from parameterized import parameterized_class import random @@ -44,36 +43,30 @@ def test_randomized(self): self.done = False - def source(): + async def source(sim: TestbenchContext): for _ in range(cycles): - if random.randint(0, 1): - yield # random delay + await self.random_wait_geom(sim, 0.5) v = random.randint(0, (2**fifoc.fifo.width) - 1) - yield from fifoc.fifo_write.call(data=v) expq.appendleft(v) + await fifoc.fifo_write.call(sim, data=v) if random.random() < 0.005: - yield from fifoc.fifo_clear.call() - yield Settle() + await fifoc.fifo_clear.call(sim) + await sim.delay(1e-9) expq.clear() self.done = True - def target(): + async def target(sim: TestbenchContext): while not self.done or expq: - if random.randint(0, 1): - yield Tick() + await self.random_wait_geom(sim, 0.5) - yield from fifoc.fifo_read.call_init() - yield Tick() + v = await fifoc.fifo_read.call_try(sim) - v = yield from fifoc.fifo_read.call_result() if v is not None: - assert v["data"] == expq.pop() - - yield from fifoc.fifo_read.disable() + assert v.data == expq.pop() with self.run_simulation(fifoc) as sim: - sim.add_process(source) - sim.add_process(target) + sim.add_testbench(source) + sim.add_testbench(target) diff --git a/test/lib/test_transaction_lib.py b/test/lib/test_transaction_lib.py index 2178973..6932e49 100644 --- a/test/lib/test_transaction_lib.py +++ b/test/lib/test_transaction_lib.py @@ -3,7 +3,6 @@ import random from operator import and_ from functools import reduce -from amaranth.sim import Settle, Tick from typing import Optional, TypeAlias from parameterized import parameterized from collections import deque @@ -11,14 +10,17 @@ from amaranth import * from transactron import * from transactron.lib import * +from transactron.testing.method_mock import MethodMock +from transactron.testing.testbenchio import CallTrigger from transactron.utils._typing import ModuleLike, MethodStruct, RecordDict from transactron.utils import ModuleConnector from transactron.testing import ( SimpleTestCircuit, TestCaseWithSimulator, - TestbenchIO, data_layout, def_method_mock, + TestbenchIO, + TestbenchContext, ) @@ -45,19 +47,19 @@ def do_test_fifo( random.seed(1337) - def writer(): + async def writer(sim: TestbenchContext): for i in range(2**iosize): - yield from m.write.call(data=i) - yield from self.random_wait(writer_rand) + await m.write.call(sim, data=i) + await self.random_wait(sim, writer_rand) - def reader(): + async def reader(sim: TestbenchContext): for i in range(2**iosize): - assert (yield from m.read.call()) == {"data": i} - yield from self.random_wait(reader_rand) + assert (await m.read.call(sim)).data == i + await self.random_wait(sim, reader_rand) with self.run_simulation(m) as sim: - sim.add_process(reader) - sim.add_process(writer) + sim.add_testbench(reader) + sim.add_testbench(writer) class TestFIFO(TestFifoBase): @@ -86,46 +88,35 @@ def test_forwarding(self): m = SimpleTestCircuit(Forwarder(data_layout(iosize))) - def forward_check(x): - yield from m.read.call_init() - yield from m.write.call_init(data=x) - yield Settle() - assert (yield from m.read.call_result()) == {"data": x} - assert (yield from m.write.call_result()) is not None - yield Tick() + async def forward_check(sim: TestbenchContext, x: int): + read_res, write_res = await CallTrigger(sim).call(m.read).call(m.write, data=x) + assert read_res is not None and read_res.data == x + assert write_res is not None - def process(): + async def process(sim: TestbenchContext): # test forwarding behavior for x in range(4): - yield from forward_check(x) + await forward_check(sim, x) # load the overflow buffer - yield from m.read.disable() - yield from m.write.call_init(data=42) - yield Settle() - assert (yield from m.write.call_result()) is not None - yield Tick() + res = await m.write.call_try(sim, data=42) + assert res is not None # writes are not possible now - yield from m.write.call_init(data=84) - yield Settle() - assert (yield from m.write.call_result()) is None - yield Tick() + res = await m.write.call_try(sim, data=42) + assert res is None # read from the overflow buffer, writes still blocked - yield from m.read.enable() - yield from m.write.call_init(data=111) - yield Settle() - assert (yield from m.read.call_result()) == {"data": 42} - assert (yield from m.write.call_result()) is None - yield Tick() + read_res, write_res = await CallTrigger(sim).call(m.read).call(m.write, data=111) + assert read_res is not None and read_res.data == 42 + assert write_res is None # forwarding now works again for x in range(4): - yield from forward_check(x) + await forward_check(sim, x) with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) class TestPipe(TestFifoBase): @@ -165,7 +156,7 @@ def test_mem( transparent=transparent, read_ports=read_ports, write_ports=write_ports, - ) + ), ) data: list[int] = [0 for _ in range(max_addr)] @@ -174,43 +165,39 @@ def test_mem( random.seed(seed) def writer(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): d = random.randrange(2**data_width) a = random.randrange(max_addr) - yield from m.writes[i].call(data=d, addr=a) - for _ in range(i + 2 if not transparent else i): - yield Settle() + await m.writes[i].call(sim, data={"data": d}, addr=a) + await sim.delay(1e-9 * (i + 2 if not transparent else i)) data[a] = d - yield from self.random_wait(writer_rand) + await self.random_wait(sim, writer_rand) return process def reader_req(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): a = random.randrange(max_addr) - yield from m.read_reqs[i].call(addr=a) - for _ in range(1 if not transparent else write_ports + 2): - yield Settle() + await m.read_reqs[i].call(sim, addr=a) + await sim.delay(1e-9 * (1 if not transparent else write_ports + 2)) d = data[a] read_req_queues[i].append(d) - yield from self.random_wait(reader_req_rand) + await self.random_wait(sim, reader_req_rand) return process def reader_resp(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): - for _ in range(write_ports + 3): - yield Settle() + await sim.delay(1e-9 * (write_ports + 3)) while not read_req_queues[i]: - yield from self.random_wait(reader_resp_rand or 1, min_cycle_cnt=1) - for _ in range(write_ports + 3): - yield Settle() + await self.random_wait(sim, reader_resp_rand or 1, min_cycle_cnt=1) + await sim.delay(1e-9 * (write_ports + 3)) d = read_req_queues[i].popleft() - assert (yield from m.read_resps[i].call()) == {"data": d} - yield from self.random_wait(reader_resp_rand) + assert (await m.read_resps[i].call(sim)).data == d + await self.random_wait(sim, reader_resp_rand) return process @@ -219,10 +206,10 @@ def process(): with self.run_simulation(m, max_cycles=max_cycles) as sim: for i in range(read_ports): - sim.add_process(reader_req(i)) - sim.add_process(reader_resp(i)) + sim.add_testbench(reader_req(i)) + sim.add_testbench(reader_resp(i)) for i in range(write_ports): - sim.add_process(writer(i)) + sim.add_testbench(writer(i)) class TestAsyncMemoryBank(TestCaseWithSimulator): @@ -238,7 +225,7 @@ def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, m = SimpleTestCircuit( AsyncMemoryBank( data_layout=[("data", data_width)], elem_count=max_addr, read_ports=read_ports, write_ports=write_ports - ) + ), ) data: list[int] = list(0 for i in range(max_addr)) @@ -246,43 +233,41 @@ def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, random.seed(seed) def writer(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): d = random.randrange(2**data_width) a = random.randrange(max_addr) - yield from m.writes[i].call(data=d, addr=a) - for _ in range(i + 2): - yield Settle() + await m.writes[i].call(sim, data={"data": d}, addr=a) + await sim.delay(1e-9 * (i + 2)) data[a] = d - yield from self.random_wait(writer_rand, min_cycle_cnt=1) + await self.random_wait(sim, writer_rand, min_cycle_cnt=1) return process def reader(i): - def process(): + async def process(sim: TestbenchContext): for cycle in range(test_count): a = random.randrange(max_addr) - d = yield from m.reads[i].call(addr=a) - for _ in range(1): - yield Settle() + d = await m.reads[i].call(sim, addr=a) + await sim.delay(1e-9) expected_d = data[a] assert d["data"] == expected_d - yield from self.random_wait(reader_rand, min_cycle_cnt=1) + await self.random_wait(sim, reader_rand, min_cycle_cnt=1) return process with self.run_simulation(m) as sim: for i in range(read_ports): - sim.add_process(reader(i)) + sim.add_testbench(reader(i)) for i in range(write_ports): - sim.add_process(writer(i)) + sim.add_testbench(writer(i)) class ManyToOneConnectTransTestCircuit(Elaboratable): def __init__(self, count: int, lay: MethodLayout): self.count = count self.lay = lay - self.inputs = [] + self.inputs: list[TestbenchIO] = [] def elaborate(self, platform): m = TModule() @@ -296,8 +281,7 @@ def elaborate(self, platform): # Create ManyToOneConnectTrans, which will serialize results from different inputs output = TestbenchIO(Adapter(i=self.lay)) - m.submodules.output = output - self.output = output + m.submodules.output = self.output = output m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface) return m @@ -341,19 +325,19 @@ def generate_producer(self, i: int): results to its output FIFO. This records will be next serialized by FUArbiter. """ - def producer(): + async def producer(sim: TestbenchContext): inputs = self.inputs[i] for field1, field2 in inputs: - io: TestbenchIO = self.m.inputs[i] - yield from io.call_init(field1=field1, field2=field2) - yield from self.random_wait(self.max_wait) + self.m.inputs[i].call_init(sim, field1=field1, field2=field2) + await self.random_wait(sim, self.max_wait) self.producer_end[i] = True return producer - def consumer(self): + async def consumer(self, sim: TestbenchContext): + # TODO: this test doesn't test anything, needs to be fixed! while reduce(and_, self.producer_end, True): - result = yield from self.m.output.call_do() + result = await self.m.output.call_do(sim) assert result is not None @@ -363,23 +347,16 @@ def consumer(self): del self.expected_output[t] else: self.expected_output[t] -= 1 - yield from self.random_wait(self.max_wait) - - def test_one_out(self): - self.count = 1 - self.initialize() - with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) - for i in range(self.count): - sim.add_process(self.generate_producer(i)) + await self.random_wait(sim, self.max_wait) - def test_many_out(self): - self.count = 4 + @pytest.mark.parametrize("count", [1, 4]) + def test(self, count: int): + self.count = count self.initialize() with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) + sim.add_testbench(self.consumer) for i in range(self.count): - sim.add_process(self.generate_producer(i)) + sim.add_testbench(self.generate_producer(i)) class MethodMapTestCircuit(Elaboratable): @@ -446,11 +423,11 @@ def _(arg: MethodStruct): class TestMethodTransformer(TestCaseWithSimulator): m: MethodMapTestCircuit - def source(self): + async def source(self, sim: TestbenchContext): for i in range(2**self.m.iosize): - v = yield from self.m.source.call(data=i) + v = await self.m.source.call(sim, data=i) i1 = (i + 1) & ((1 << self.m.iosize) - 1) - assert v["data"] == (((i1 << 1) | (i1 >> (self.m.iosize - 1))) - 1) & ((1 << self.m.iosize) - 1) + assert v.data == (((i1 << 1) | (i1 >> (self.m.iosize - 1))) - 1) & ((1 << self.m.iosize) - 1) @def_method_mock(lambda self: self.m.target) def target(self, data): @@ -459,18 +436,17 @@ def target(self, data): def test_method_transformer(self): self.m = MethodMapTestCircuit(4, False, False) with self.run_simulation(self.m) as sim: - sim.add_process(self.source) - sim.add_process(self.target) + sim.add_testbench(self.source) def test_method_transformer_dicts(self): self.m = MethodMapTestCircuit(4, False, True) with self.run_simulation(self.m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) def test_method_transformer_with_methods(self): self.m = MethodMapTestCircuit(4, True, True) with self.run_simulation(self.m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) class TestMethodFilter(TestCaseWithSimulator): @@ -480,34 +456,31 @@ def initialize(self): self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout)) self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) - def source(self): + async def source(self, sim: TestbenchContext): for i in range(2**self.iosize): - v = yield from self.tc.method.call(data=i) + v = await self.tc.method.call(sim, data=i) if i & 1: - assert v["data"] == (i + 1) & ((1 << self.iosize) - 1) + assert v.data == (i + 1) & ((1 << self.iosize) - 1) else: - assert v["data"] == 0 + assert v.data == 0 - @def_method_mock(lambda self: self.target, sched_prio=2) + @def_method_mock(lambda self: self.target) def target_mock(self, data): return {"data": data + 1} - @def_method_mock(lambda self: self.cmeth, sched_prio=1) + @def_method_mock(lambda self: self.cmeth) def cmeth_mock(self, data): return {"data": data % 2} - @parameterized.expand([(True,), (False,)]) - def test_method_filter_with_methods(self, use_condition): + def test_method_filter_with_methods(self): self.initialize() - self.tc = SimpleTestCircuit( - MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface, use_condition=use_condition) - ) + self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface)) m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) with self.run_simulation(m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) @parameterized.expand([(True,), (False,)]) - def test_method_filter(self, use_condition): + def test_method_filter_plain(self, use_condition): self.initialize() def condition(_, v): @@ -516,7 +489,7 @@ def condition(_, v): self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition, use_condition=use_condition)) m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) with self.run_simulation(m) as sim: - sim.add_process(self.source) + sim.add_testbench(self.source) class MethodProductTestCircuit(Elaboratable): @@ -562,36 +535,36 @@ def test_method_product(self, targets: int, add_combiner: bool): def target_process(k: int): @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def process(data): + def mock(data): return {"data": data + k} - return process + return mock() - def method_process(): + async def method_process(sim: TestbenchContext): # if any of the target methods is not enabled, call does not succeed for i in range(2**targets - 1): for k in range(targets): method_en[k] = bool(i & (1 << k)) - yield Tick() - assert (yield from m.method.call_try(data=0)) is None + await sim.tick() + assert (await m.method.call_try(sim, data=0)) is None # otherwise, the call succeeds for k in range(targets): method_en[k] = True - yield Tick() + await sim.tick() data = random.randint(0, (1 << iosize) - 1) - val = (yield from m.method.call(data=data))["data"] + val = (await m.method.call(sim, data=data)).data if add_combiner: assert val == (targets * data + (targets - 1) * targets // 2) & ((1 << iosize) - 1) else: assert val == data with self.run_simulation(m) as sim: - sim.add_process(method_process) + sim.add_testbench(method_process) for k in range(targets): - sim.add_process(target_process(k)) + self.add_mock(sim, target_process(k)) class TestSerializer(TestCaseWithSimulator): @@ -613,7 +586,7 @@ def setup_method(self): port_count=self.port_count, serialized_req_method=self.req_method.adapter.iface, serialized_resp_method=self.resp_method.adapter.iface, - ) + ), ) self.m = ModuleConnector( test_circuit=self.test_circuit, req_method=self.req_method, resp_method=self.resp_method @@ -628,38 +601,44 @@ def setup_method(self): @def_method_mock(lambda self: self.req_method, enable=lambda self: not self.got_request) def serial_req_mock(self, field): - self.serialized_data.append(field) - self.got_request = True + @MethodMock.effect + def eff(): + self.serialized_data.append(field) + self.got_request = True @def_method_mock(lambda self: self.resp_method, enable=lambda self: self.got_request) def serial_resp_mock(self): - self.got_request = False - return {"field": self.serialized_data[-1]} + @MethodMock.effect + def eff(): + self.got_request = False + + if self.serialized_data: + return {"field": self.serialized_data[-1]} def requestor(self, i: int): - def f(): + async def f(sim: TestbenchContext): for _ in range(self.test_count): d = random.randrange(2**self.data_width) - yield from self.test_circuit.serialize_in[i].call(field=d) + await self.test_circuit.serialize_in[i].call(sim, field=d) self.port_data[i].append(d) - yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) + await self.random_wait(sim, self.requestor_rand, min_cycle_cnt=1) return f def responder(self, i: int): - def f(): + async def f(sim: TestbenchContext): for _ in range(self.test_count): - data_out = yield from self.test_circuit.serialize_out[i].call() - assert self.port_data[i].popleft() == data_out["field"] - yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) + data_out = await self.test_circuit.serialize_out[i].call(sim) + assert self.port_data[i].popleft() == data_out.field + await self.random_wait(sim, self.requestor_rand, min_cycle_cnt=1) return f def test_serial(self): with self.run_simulation(self.m) as sim: for i in range(self.port_count): - sim.add_process(self.requestor(i)) - sim.add_process(self.responder(i)) + sim.add_testbench(self.requestor(i)) + sim.add_testbench(self.responder(i)) class TestMethodTryProduct(TestCaseWithSimulator): @@ -674,32 +653,32 @@ def test_method_try_product(self, targets: int, add_combiner: bool): def target_process(k: int): @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def process(data): + def mock(data): return {"data": data + k} - return process + return mock() - def method_process(): + async def method_process(sim: TestbenchContext): for i in range(2**targets): for k in range(targets): method_en[k] = bool(i & (1 << k)) active_targets = sum(method_en) - yield Tick() + await sim.tick() data = random.randint(0, (1 << iosize) - 1) - val = yield from m.method.call(data=data) + val = await m.method.call(sim, data=data) if add_combiner: adds = sum(k * method_en[k] for k in range(targets)) - assert val == {"data": (active_targets * data + adds) & ((1 << iosize) - 1)} + assert val.data == (active_targets * data + adds) & ((1 << iosize) - 1) else: - assert val == {} + assert val.shape().size == 0 with self.run_simulation(m) as sim: - sim.add_process(method_process) + sim.add_testbench(method_process) for k in range(targets): - sim.add_process(target_process(k)) + self.add_mock(sim, target_process(k)) class MethodTryProductTestCircuit(Elaboratable): @@ -768,7 +747,7 @@ def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): target = TestbenchIO(Adapter(i=[("cond", 2)])) circ = SimpleTestCircuit( - ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall) + ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall), ) m = ModuleConnector(test_circuit=circ, target=target) @@ -776,14 +755,17 @@ def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): @def_method_mock(lambda: target) def target_process(cond): - nonlocal selection - selection = cond + @MethodMock.effect + def eff(): + nonlocal selection + selection = cond - def process(): + async def process(sim: TestbenchContext): nonlocal selection + await sim.tick() # TODO workaround for mocks inactive in first cycle for c1, c2, c3 in product([0, 1], [0, 1], [0, 1]): selection = None - res = yield from circ.source.call_try(cond1=c1, cond2=c2, cond3=c3) + res = await circ.source.call_try(sim, cond1=c1, cond2=c2, cond3=c3) if catchall or nonblocking: assert res is not None @@ -801,4 +783,4 @@ def process(): assert selection in [c1, 2 * c2, 3 * c3] with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/test_adapter.py b/test/test_adapter.py index a5fa732..93d0611 100644 --- a/test/test_adapter.py +++ b/test/test_adapter.py @@ -2,8 +2,8 @@ from transactron import Method, def_method, TModule - -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, ModuleConnector +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, TestbenchContext +from transactron.utils.amaranth_ext.elaboratables import ModuleConnector class Echo(Elaboratable): @@ -45,12 +45,11 @@ def _(arg): class TestAdapterTrans(TestCaseWithSimulator): - def proc(self): + async def proc(self, sim: TestbenchContext): for _ in range(3): - # this would previously timeout if the output layout was empty (as is in this case) - yield from self.consumer.action.call(data=0) + await self.consumer.action.call(sim, data=0) for expected in [4, 1, 0]: - obtained = (yield from self.echo.action.call(data=expected))["data"] + obtained = (await self.echo.action.call(sim, data=expected)).data assert expected == obtained def test_single(self): @@ -59,4 +58,4 @@ def test_single(self): self.m = ModuleConnector(echo=self.echo, consumer=self.consumer) with self.run_simulation(self.m, max_cycles=100) as sim: - sim.add_process(self.proc) + sim.add_testbench(self.proc) diff --git a/test/test_connectors.py b/test/test_connectors.py index ac15a9f..e147a2f 100644 --- a/test/test_connectors.py +++ b/test/test_connectors.py @@ -1,10 +1,8 @@ import random from parameterized import parameterized_class -from amaranth.sim import Settle, Tick - from transactron.lib import StableSelectingNetwork -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext @parameterized_class( @@ -19,7 +17,7 @@ def test(self): random.seed(42) - def process(): + async def process(sim: TestbenchContext): for _ in range(100): inputs = [random.randrange(2**8) for _ in range(self.n)] valids = [random.randrange(2) for _ in range(self.n)] @@ -27,20 +25,18 @@ def process(): expected_output_prefix = [] for i in range(self.n): - yield m.valids[i].eq(valids[i]) - yield m.inputs[i].data.eq(inputs[i]) + sim.set(m.valids[i], valids[i]) + sim.set(m.inputs[i].data, inputs[i]) if valids[i]: expected_output_prefix.append(inputs[i]) - yield Settle() - for i in range(total): - out = yield m.outputs[i].data + out = sim.get(m.outputs[i].data) assert out == expected_output_prefix[i] - assert (yield m.output_cnt) == total - yield Tick() + assert sim.get(m.output_cnt) == total + await sim.tick() with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/test_methods.py b/test/test_methods.py index e03ae5f..e4a5ced 100644 --- a/test/test_methods.py +++ b/test/test_methods.py @@ -154,14 +154,14 @@ def definition(idx: int, foo: Value): circuit = SimpleTestCircuit(TestDefMethods.CircuitTestModule(definition)) - def test_process(): + async def test_process(sim): for k, method in enumerate(circuit.methods): val = random.randrange(0, 2**3) - ret = yield from method.call(foo=val) + ret = await method.call(sim, foo=val) assert ret["foo"] == (val + k) % 2**3 with self.run_simulation(circuit) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) class AdapterCircuit(Elaboratable): @@ -392,13 +392,13 @@ class TestQuadrupleCircuits(TestCaseWithSimulator): def test(self, quadruple): circ = QuadrupleCircuit(quadruple()) - def process(): + async def process(sim): for n in range(1 << (WIDTH - 2)): - out = yield from circ.tb.call(data=n) + out = await circ.tb.call(sim, data=n) assert out["data"] == n * 4 with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class ConditionalCallCircuit(Elaboratable): @@ -483,31 +483,27 @@ class TestConditionals(TestCaseWithSimulator): def test_conditional_call(self): circ = ConditionalCallCircuit() - def process(): - yield from circ.out.disable() - yield from circ.tb.call_init(data=0) - yield Settle() - assert not (yield from circ.out.done()) - assert not (yield from circ.tb.done()) + async def process(sim): + circ.out.disable(sim) + circ.tb.call_init(sim, data=0) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert not out_done and not tb_done - yield from circ.out.enable() - yield Settle() - assert not (yield from circ.out.done()) - assert (yield from circ.tb.done()) + circ.out.enable(sim) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert not out_done and tb_done - yield from circ.tb.call_init(data=1) - yield Settle() - assert (yield from circ.out.done()) - assert (yield from circ.tb.done()) + circ.tb.call_init(sim, data=1) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert out_done and tb_done # the argument is still 1 but the method is not called - yield from circ.tb.disable() - yield Settle() - assert not (yield from circ.out.done()) - assert not (yield from circ.tb.done()) + circ.tb.disable(sim) + *_, out_done, tb_done = await sim.tick().sample(circ.out.adapter.done, circ.tb.adapter.done) + assert not out_done and not tb_done with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) @parameterized.expand( [ @@ -520,18 +516,18 @@ def process(): def test_conditional(self, elaboratable): circ = elaboratable() - def process(): - yield from circ.tb.enable() - yield circ.ready.eq(0) - yield Settle() - assert not (yield from circ.tb.done()) + async def process(sim): + circ.tb.enable(sim) + sim.set(circ.ready, 0) + *_, tb_done = await sim.tick().sample(circ.tb.adapter.done) + assert not tb_done - yield circ.ready.eq(1) - yield Settle() - assert (yield from circ.tb.done()) + sim.set(circ.ready, 1) + *_, tb_done = await sim.tick().sample(circ.tb.adapter.done) + assert tb_done with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class NonexclusiveMethodCircuit(Elaboratable): @@ -559,42 +555,33 @@ class TestNonexclusiveMethod(TestCaseWithSimulator): def test_nonexclusive_method(self): circ = NonexclusiveMethodCircuit() - def process(): + async def process(sim): for x in range(8): t1en = bool(x & 1) t2en = bool(x & 2) mrdy = bool(x & 4) - if t1en: - yield from circ.t1.enable() - else: - yield from circ.t1.disable() + circ.t1.set_enable(sim, t1en) + circ.t2.set_enable(sim, t2en) + sim.set(circ.ready, int(mrdy)) + sim.set(circ.data, x) - if t2en: - yield from circ.t2.enable() - else: - yield from circ.t2.disable() - - if mrdy: - yield circ.ready.eq(1) - else: - yield circ.ready.eq(0) - - yield circ.data.eq(x) - yield Settle() + *_, running, t1_done, t2_done, t1_outputs, t2_outputs = await sim.delay(1e-9).sample( + circ.running, circ.t1.done, circ.t2.done, circ.t1.outputs, circ.t2.outputs + ) - assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) - assert bool((yield from circ.t1.done())) == (t1en and mrdy) - assert bool((yield from circ.t2.done())) == (t2en and mrdy) + assert bool(running) == ((t1en or t2en) and mrdy) + assert bool(t1_done) == (t1en and mrdy) + assert bool(t2_done) == (t2en and mrdy) if t1en and mrdy: - assert (yield from circ.t1.get_outputs()) == {"data": x} + assert t1_outputs["data"] == x if t2en and mrdy: - assert (yield from circ.t2.get_outputs()) == {"data": x} + assert t2_outputs["data"] == x with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class TwoNonexclusiveConflictCircuit(Elaboratable): @@ -638,15 +625,15 @@ class TestConflicting(TestCaseWithSimulator): def test_conflicting(self, test_circuit: Callable[[], TwoNonexclusiveConflictCircuit]): circ = test_circuit() - def process(): - yield from circ.t1.enable() - yield from circ.t2.enable() - yield Settle() + async def process(sim): + circ.t1.enable(sim) + circ.t2.enable(sim) + *_, running1, running2 = await sim.delay(1e-9).sample(circ.running1, circ.running2) - assert not (yield circ.running1) or not (yield circ.running2) + assert not running1 or not running2 with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class CustomCombinerMethodCircuit(Elaboratable): @@ -679,7 +666,7 @@ class TestCustomCombinerMethod(TestCaseWithSimulator): def test_custom_combiner_method(self): circ = CustomCombinerMethodCircuit() - def process(): + async def process(sim): for x in range(8): t1en = bool(x & 1) t2en = bool(x & 2) @@ -690,38 +677,30 @@ def process(): val1e = val1 if t1en else 0 val2e = val2 if t2en else 0 - yield from circ.t1.call_init(data=val1) - yield from circ.t2.call_init(data=val2) + circ.t1.call_init(sim, data=val1) + circ.t2.call_init(sim, data=val2) - if t1en: - yield from circ.t1.enable() - else: - yield from circ.t1.disable() + circ.t1.set_enable(sim, t1en) + circ.t2.set_enable(sim, t2en) - if t2en: - yield from circ.t2.enable() - else: - yield from circ.t2.disable() + sim.set(circ.ready, int(mrdy)) - if mrdy: - yield circ.ready.eq(1) - else: - yield circ.ready.eq(0) - - yield Settle() + *_, running, t1_done, t2_done, t1_outputs, t2_outputs = await sim.delay(1e-9).sample( + circ.running, circ.t1.done, circ.t2.done, circ.t1.outputs, circ.t2.outputs + ) - assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) - assert bool((yield from circ.t1.done())) == (t1en and mrdy) - assert bool((yield from circ.t2.done())) == (t2en and mrdy) + assert bool(running) == ((t1en or t2en) and mrdy) + assert bool(t1_done) == (t1en and mrdy) + assert bool(t2_done) == (t2en and mrdy) if t1en and mrdy: - assert (yield from circ.t1.get_outputs()) == {"data": val1e ^ val2e} + assert t1_outputs["data"] == val1e ^ val2e if t2en and mrdy: - assert (yield from circ.t2.get_outputs()) == {"data": val1e ^ val2e} + assert t2_outputs["data"] == val1e ^ val2e with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class DataDependentConditionalCircuit(Elaboratable): @@ -729,8 +708,8 @@ def __init__(self, n=2, ready_function=lambda arg: arg.data != 3): self.method = Method(i=data_layout(n)) self.ready_function = ready_function - self.in_t1 = Signal(from_method_layout(data_layout(n))) - self.in_t2 = Signal(from_method_layout(data_layout(n))) + self.in_t1 = Signal(n) + self.in_t2 = Signal(n) self.ready = Signal() self.req_t1 = Signal() self.req_t2 = Signal() @@ -748,11 +727,11 @@ def _(data): with Transaction().body(m, request=self.req_t1): m.d.comb += self.out_t1.eq(1) - self.method(m, self.in_t1) + self.method(m, data=self.in_t1) with Transaction().body(m, request=self.req_t2): m.d.comb += self.out_t2.eq(1) - self.method(m, self.in_t2) + self.method(m, data=self.in_t2) return m @@ -767,7 +746,7 @@ def base_random(self, f): random.seed(14) self.circ = DataDependentConditionalCircuit(n=self.n, ready_function=f) - def process(): + async def process(sim): for _ in range(self.test_number): in1 = random.randrange(0, 2**self.n) in2 = random.randrange(0, 2**self.n) @@ -775,16 +754,15 @@ def process(): req_t1 = random.randrange(2) req_t2 = random.randrange(2) - yield self.circ.in_t1.eq(in1) - yield self.circ.in_t2.eq(in2) - yield self.circ.req_t1.eq(req_t1) - yield self.circ.req_t2.eq(req_t2) - yield self.circ.ready.eq(m_ready) - yield Settle() + sim.set(self.circ.in_t1, in1) + sim.set(self.circ.in_t2, in2) + sim.set(self.circ.req_t1, req_t1) + sim.set(self.circ.req_t2, req_t2) + sim.set(self.circ.ready, m_ready) - out_m = yield self.circ.out_m - out_t1 = yield self.circ.out_t1 - out_t2 = yield self.circ.out_t2 + *_, out_m, out_t1, out_t2 = await sim.delay(1e-9).sample( + self.circ.out_m, self.circ.out_t1, self.circ.out_t2 + ) if not m_ready or (not req_t1 or in1 == self.bad_number) and (not req_t2 or in2 == self.bad_number): assert out_m == 0 @@ -800,10 +778,10 @@ def process(): assert in1 != self.bad_number or not out_t1 assert in2 != self.bad_number or not out_t2 - yield Tick() + await sim.tick() with self.run_simulation(self.circ, 100) as sim: - sim.add_process(process) + sim.add_testbench(process) def test_random_arg(self): self.base_random(lambda arg: arg.data != self.bad_number) diff --git a/test/test_metrics.py b/test/test_metrics.py index c7a0f07..181dcc8 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -7,12 +7,11 @@ from parameterized import parameterized_class from amaranth import * -from amaranth.sim import Settle, Tick from transactron.lib.metrics import * from transactron import * -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit -from transactron.testing.infrastructure import Now +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, TestbenchContext +from transactron.testing.tick_count import TicksKey from transactron.utils.dependencies import DependencyContext @@ -74,72 +73,62 @@ def test_counter_in_method(self): m = SimpleTestCircuit(CounterInMethodCircuit()) DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): called_cnt = 0 for _ in range(200): call_now = random.randint(0, 1) == 0 if call_now: - yield from m.method.call() + await m.method.call(sim) + called_cnt += 1 else: - yield Tick() - - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m._dut.counter.count.value) + await sim.tick() - if call_now: - called_cnt += 1 + assert called_cnt == sim.get(m._dut.counter.count.value) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) def test_counter_with_condition_in_method(self): m = SimpleTestCircuit(CounterWithConditionInMethodCircuit()) DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): called_cnt = 0 for _ in range(200): call_now = random.randint(0, 1) == 0 condition = random.randint(0, 1) if call_now: - yield from m.method.call(cond=condition) + await m.method.call(sim, cond=condition) + called_cnt += condition else: - yield Tick() + await sim.tick() - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m._dut.counter.count.value) - - if call_now and condition == 1: - called_cnt += 1 + assert called_cnt == sim.get(m._dut.counter.count.value) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) def test_counter_with_condition_without_method(self): m = CounterWithoutMethodCircuit() DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): called_cnt = 0 for _ in range(200): condition = random.randint(0, 1) - yield m.cond.eq(condition) - yield Tick() - - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m.counter.count.value) + sim.set(m.cond, condition) + await sim.tick() if condition == 1: called_cnt += 1 + assert called_cnt == sim.get(m.counter.count.value) + with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) class OneHotEnum(IntFlag): @@ -184,23 +173,23 @@ def do_test_enum(self, tags: range | Type[Enum] | list[int], tag_values: list[in for i in tag_values: counts[i] = 0 - def test_process(): + async def test_process(sim): for _ in range(200): for i in tag_values: - assert counts[i] == (yield m.counter.counters[i].value) + assert counts[i] == sim.get(m.counter.counters[i].value) tag = random.choice(list(tag_values)) - yield m.cond.eq(1) - yield m.tag.eq(tag) - yield Tick() - yield m.cond.eq(0) - yield Tick() + sim.set(m.cond, 1) + sim.set(m.tag, tag) + await sim.tick() + sim.set(m.cond, 0) + await sim.tick() counts[tag] += 1 with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) def test_one_hot_enum(self): self.do_test_enum(OneHotEnum, [e.value for e in OneHotEnum]) @@ -261,8 +250,8 @@ def test_histogram(self): max_sample_value = 2**self.sample_width - 1 - def test_process(): - min = max_sample_value + 1 + async def test_process(sim): + min = max_sample_value max = 0 sum = 0 count = 0 @@ -282,46 +271,43 @@ def test_process(): if value < 2**i or i == self.bucket_count - 1: buckets[i] += 1 break - yield from m.method.call(data=value) - yield Tick() + await m.method.call(sim, data=value) else: - yield Tick() + await sim.tick() histogram = m._dut.histogram - # Skip the assertion if the min is still uninitialized - if min != max_sample_value + 1: - assert min == (yield histogram.min.value) - assert max == (yield histogram.max.value) - assert sum == (yield histogram.sum.value) - assert count == (yield histogram.count.value) + assert min == sim.get(histogram.min.value) + assert max == sim.get(histogram.max.value) + assert sum == sim.get(histogram.sum.value) + assert count == sim.get(histogram.count.value) total_count = 0 for i in range(self.bucket_count): - bucket_value = yield histogram.buckets[i].value + bucket_value = sim.get(histogram.buckets[i].value) total_count += bucket_value assert buckets[i] == bucket_value # Sanity check if all buckets sum up to the total count value - assert total_count == (yield histogram.count.value) + assert total_count == sim.get(histogram.count.value) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) class TestLatencyMeasurerBase(TestCaseWithSimulator): - def check_latencies(self, m: SimpleTestCircuit, latencies: list[int]): - assert min(latencies) == (yield m._dut.histogram.min.value) - assert max(latencies) == (yield m._dut.histogram.max.value) - assert sum(latencies) == (yield m._dut.histogram.sum.value) - assert len(latencies) == (yield m._dut.histogram.count.value) + def check_latencies(self, sim, m: SimpleTestCircuit, latencies: list[int]): + assert min(latencies) == sim.get(m._dut.histogram.min.value) + assert max(latencies) == sim.get(m._dut.histogram.max.value) + assert sum(latencies) == sim.get(m._dut.histogram.sum.value) + assert len(latencies) == sim.get(m._dut.histogram.count.value) for i in range(m._dut.histogram.bucket_count): bucket_start = 0 if i == 0 else 2 ** (i - 1) bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i count = sum(1 for x in latencies if bucket_start <= x < bucket_end) - assert count == (yield m._dut.histogram.buckets[i].value) + assert count == sim.get(m._dut.histogram.buckets[i].value) @parameterized_class( @@ -351,36 +337,33 @@ def test_latency_measurer(self): finish = False - def producer(): + async def producer(sim: TestbenchContext): nonlocal finish + ticks = DependencyContext.get().get_dependency(TicksKey()) for _ in range(200): - yield from m._start.call() + await m._start.call(sim) - # Make sure that the time is updated first. - yield Settle() - time = yield Now() - event_queue.put(time) - yield from self.random_wait_geom(0.8) + event_queue.put(sim.get(ticks)) + await self.random_wait_geom(sim, 0.8) finish = True - def consumer(): + async def consumer(sim: TestbenchContext): + ticks = DependencyContext.get().get_dependency(TicksKey()) + while not finish: - yield from m._stop.call() + await m._stop.call(sim) - # Make sure that the time is updated first. - yield Settle() - time = yield Now() - latencies.append(time - event_queue.get()) + latencies.append(sim.get(ticks) - event_queue.get()) - yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) + await self.random_wait_geom(sim, 1.0 / self.expected_consumer_wait) - self.check_latencies(m, latencies) + self.check_latencies(sim, m, latencies) with self.run_simulation(m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) @parameterized_class( @@ -412,54 +395,51 @@ def test_latency_measurer(self): finish = False - def producer(): + async def producer(sim): nonlocal finish + tick_count = DependencyContext.get().get_dependency(TicksKey()) + for _ in range(200): while not free_slots: - yield Tick() - continue - yield Settle() + await sim.tick() + await sim.delay(1e-9) slot_id = random.choice(free_slots) - yield from m._start.call(slot=slot_id) + await m._start.call(sim, slot=slot_id) - time = yield Now() - - events[slot_id] = time + events[slot_id] = sim.get(tick_count) free_slots.remove(slot_id) used_slots.append(slot_id) - yield from self.random_wait_geom(0.8) + await self.random_wait_geom(sim, 0.8) finish = True - def consumer(): + async def consumer(sim): + tick_count = DependencyContext.get().get_dependency(TicksKey()) + while not finish: while not used_slots: - yield Tick() - continue + await sim.tick() slot_id = random.choice(used_slots) - yield from m._stop.call(slot=slot_id) - - time = yield Now() + await m._stop.call(sim, slot=slot_id) - yield Settle() - yield Settle() + await sim.delay(2e-9) - latencies.append(time - events[slot_id]) + latencies.append(sim.get(tick_count) - events[slot_id]) used_slots.remove(slot_id) free_slots.append(slot_id) - yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) + await self.random_wait_geom(sim, 1.0 / self.expected_consumer_wait) - self.check_latencies(m, latencies) + self.check_latencies(sim, m, latencies) with self.run_simulation(m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) class MetricManagerTestCircuit(Elaboratable): @@ -527,21 +507,20 @@ def test_returned_reg_values(self): DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - def test_process(): + async def test_process(sim): counters = [0] * 3 for _ in range(200): rand = [random.randint(0, 1) for _ in range(3)] - yield from m.incr_counters.call(counter1=rand[0], counter2=rand[1], counter3=rand[2]) - yield Tick() + await m.incr_counters.call(sim, counter1=rand[0], counter2=rand[1], counter3=rand[2]) for i in range(3): if rand[i] == 1: counters[i] += 1 - assert counters[0] == (yield metrics_manager.get_register_value("foo.counter1", "count")) - assert counters[1] == (yield metrics_manager.get_register_value("bar.baz.counter2", "count")) - assert counters[2] == (yield metrics_manager.get_register_value("bar.baz.counter3", "count")) + assert counters[0] == sim.get(metrics_manager.get_register_value("foo.counter1", "count")) + assert counters[1] == sim.get(metrics_manager.get_register_value("bar.baz.counter2", "count")) + assert counters[2] == sim.get(metrics_manager.get_register_value("bar.baz.counter3", "count")) with self.run_simulation(m) as sim: - sim.add_process(test_process) + sim.add_testbench(test_process) diff --git a/test/test_simultaneous.py b/test/test_simultaneous.py index d085930..ad492e3 100644 --- a/test/test_simultaneous.py +++ b/test/test_simultaneous.py @@ -3,10 +3,12 @@ from typing import Optional from amaranth import * from amaranth.sim import * +from transactron.testing.method_mock import MethodMock, def_method_mock +from transactron.testing.testbenchio import TestbenchIO from transactron.utils import ModuleConnector -from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock +from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchContext from transactron import * from transactron.lib import Adapter, Connect, ConnectTrans @@ -44,17 +46,17 @@ class TestSimultaneousDiamond(TestCaseWithSimulator): def test_diamond(self): circ = SimpleTestCircuit(SimultaneousDiamondTestCircuit()) - def process(): + async def process(sim: TestbenchContext): methods = {"l": circ.method_l, "r": circ.method_r, "u": circ.method_u, "d": circ.method_d} for i in range(1 << len(methods)): enables: dict[str, bool] = {} for k, n in enumerate(methods): enables[n] = bool(i & (1 << k)) - yield from methods[n].set_enable(enables[n]) - yield Tick() + methods[n].set_enable(sim, enables[n]) dones: dict[str, bool] = {} for n in methods: - dones[n] = bool((yield from methods[n].done())) + dones[n] = bool(methods[n].get_done(sim)) + await sim.tick() for n in methods: if not enables[n]: assert not dones[n] @@ -66,7 +68,7 @@ def process(): assert not any(dones.values()) with self.run_simulation(circ) as sim: - sim.add_process(process) + sim.add_testbench(process) class UnsatisfiableTriangleTestCircuit(Elaboratable): @@ -148,17 +150,19 @@ def test_transitivity(self): result: Optional[int] @def_method_mock(lambda: target) - def target_process(data): - nonlocal result - result = data + def target_process(data: int): + @MethodMock.effect + def eff(): + nonlocal result + result = data - def process(): + async def process(sim: TestbenchContext): nonlocal result for source, data, reqv1, reqv2 in product([circ.source1, circ.source2], [0, 1, 2, 3], [0, 1], [0, 1]): result = None - yield req1.eq(reqv1) - yield req2.eq(reqv2) - call_result = yield from source.call_try(data=data) + sim.set(req1, reqv1) + sim.set(req2, reqv2) + call_result = await source.call_try(sim, data=data) if not reqv1 and not reqv2: assert call_result is None @@ -169,4 +173,4 @@ def process(): assert result in possibles with self.run_simulation(m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/test_transactron_lib_storage.py b/test/test_transactron_lib_storage.py index 404c14a..d5513fe 100644 --- a/test/test_transactron_lib_storage.py +++ b/test/test_transactron_lib_storage.py @@ -31,28 +31,24 @@ def generic_process( settle_count=0, name="", ): - def f(): + async def f(sim: TestbenchContext): while input_lst: # wait till all processes will end the previous cycle - yield from self.multi_settle(4) + await sim.delay(1e-9) elem = input_lst.pop() if isinstance(elem, OpNOP): - yield Tick() + await sim.tick() continue if input_verification is not None and not input_verification(elem): - yield Tick() + await sim.tick() continue - response = yield from method.call(**elem) - yield from self.multi_settle(settle_count) + response = await method.call(sim, **elem) + await sim.delay(settle_count * 1e-9) if behaviour_check is not None: - # Here accesses to circuit are allowed - ret = behaviour_check(elem, response) - if isinstance(ret, Generator): - yield from ret + behaviour_check(elem, response) if state_change is not None: - # It is standard python function by purpose to don't allow accessing circuit state_change(elem, response) - yield Tick() + await sim.tick() return f @@ -77,10 +73,10 @@ def check(elem, response): addr = elem["addr"] frozen_addr = frozenset(addr.items()) if frozen_addr in self.memory: - assert response["not_found"] == 0 - assert response["data"] == self.memory[frozen_addr] + assert response.not_found == 0 + assert data_const_to_dict(response.data) == self.memory[frozen_addr] else: - assert response["not_found"] == 1 + assert response.not_found == 1 return self.generic_process(self.circ.read, in_read, behaviour_check=check, settle_count=0, name="read") @@ -97,7 +93,7 @@ def verify_in(elem): return ret def check(elem, response): - assert response["not_found"] == int(frozenset(elem["addr"].items()) not in self.memory) + assert response.not_found == int(frozenset(elem["addr"].items()) not in self.memory) def modify_state(elem, response): if frozenset(elem["addr"].items()) in self.memory: @@ -129,7 +125,7 @@ def test_random(self, in_push, in_write, in_read, in_remove): with self.reinitialize_fixtures(): self.setUp() with self.run_simulation(self.circ, max_cycles=500) as sim: - sim.add_process(self.push_process(in_push)) - sim.add_process(self.read_process(in_read)) - sim.add_process(self.write_process(in_write)) - sim.add_process(self.remove_process(in_remove)) + sim.add_testbench(self.push_process(in_push)) + sim.add_testbench(self.read_process(in_read)) + sim.add_testbench(self.write_process(in_write)) + sim.add_testbench(self.remove_process(in_remove)) diff --git a/test/testing/test_infrastructure.py b/test/testing/test_infrastructure.py deleted file mode 100644 index cfd59ec..0000000 --- a/test/testing/test_infrastructure.py +++ /dev/null @@ -1,31 +0,0 @@ -from amaranth import * -from transactron.testing import * - - -class EmptyCircuit(Elaboratable): - def __init__(self): - pass - - def elaborate(self, platform): - m = Module() - return m - - -class TestNow(TestCaseWithSimulator): - def setup_method(self): - self.test_cycles = 10 - self.m = SimpleTestCircuit(EmptyCircuit()) - - def process(self): - for k in range(self.test_cycles): - now = yield Now() - assert k == now - # check if second call don't change the returned value - now = yield Now() - assert k == now - - yield Tick() - - def test_random(self): - with self.run_simulation(self.m, 50) as sim: - sim.add_process(self.process) diff --git a/test/testing/test_log.py b/test/testing/test_log.py index 6e6711d..97efb9d 100644 --- a/test/testing/test_log.py +++ b/test/testing/test_log.py @@ -1,10 +1,9 @@ import pytest import re from amaranth import * -from amaranth.sim import Tick from transactron import * -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext from transactron.lib import logging LOGGER_NAME = "test_logger" @@ -70,13 +69,13 @@ class TestLog(TestCaseWithSimulator): def test_log(self, caplog): m = LogTest() - def proc(): + async def proc(sim: TestbenchContext): for i in range(50): - yield Tick() - yield m.input.eq(i) + await sim.tick() + sim.set(m.input, i) with self.run_simulation(m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) assert re.search( r"WARNING test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " @@ -93,13 +92,13 @@ def proc(): def test_error_log(self, caplog): m = ErrorLogTest() - def proc(): - yield Tick() - yield m.input.eq(1) + async def proc(sim: TestbenchContext): + await sim.tick() + sim.set(m.input, 1) with pytest.raises(AssertionError): with self.run_simulation(m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) assert re.search( r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " @@ -110,13 +109,13 @@ def proc(): def test_assertion(self, caplog): m = AssertionTest() - def proc(): - yield Tick() - yield m.input.eq(1) + async def proc(sim: TestbenchContext): + await sim.tick() + sim.set(m.input, 1) with pytest.raises(AssertionError): with self.run_simulation(m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) assert re.search( r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] Output differs", diff --git a/test/testing/test_validate_arguments.py b/test/testing/test_validate_arguments.py index 18066ff..7e70369 100644 --- a/test/testing/test_validate_arguments.py +++ b/test/testing/test_validate_arguments.py @@ -2,11 +2,12 @@ from amaranth import * from amaranth.sim import * -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout, TestbenchContext from transactron import * -from transactron.testing.sugar import def_method_mock +from transactron.testing.method_mock import def_method_mock from transactron.lib import * +from transactron.testing.testbenchio import CallTrigger class ValidateArgumentsTestCircuit(Elaboratable): @@ -24,25 +25,26 @@ def elaborate(self, platform): class TestValidateArguments(TestCaseWithSimulator): def control_caller(self, caller: TestbenchIO, method: TestbenchIO): - def process(): + async def process(sim: TestbenchContext): + await sim.tick() for _ in range(100): val = random.randrange(2) pre_accepted_val = self.accepted_val - ret = yield from caller.call_try(data=val) - if ret is None: - assert val != pre_accepted_val or val == pre_accepted_val and (yield from method.done()) - else: + caller_data, method_data = await CallTrigger(sim).call(caller, data=val).sample(method) + if caller_data is not None: assert val == pre_accepted_val - assert ret["data"] == val + assert caller_data.data == val + else: + assert val != pre_accepted_val or val == pre_accepted_val and method_data is not None return process def validate_arguments(self, data: int): return data == self.accepted_val - def changer(self): + async def changer(self, sim: TestbenchContext): for _ in range(50): - yield Tick("sync_neg") + await sim.tick() self.accepted_val = 1 @def_method_mock(tb_getter=lambda self: self.m.method, validate_arguments=validate_arguments) @@ -54,6 +56,6 @@ def test_validate_arguments(self): self.m = ValidateArgumentsTestCircuit() self.accepted_val = 0 with self.run_simulation(self.m) as sim: - sim.add_process(self.changer) - sim.add_process(self.control_caller(self.m.caller1, self.m.method)) - sim.add_process(self.control_caller(self.m.caller2, self.m.method)) + sim.add_testbench(self.changer) + sim.add_testbench(self.control_caller(self.m.caller1, self.m.method)) + sim.add_testbench(self.control_caller(self.m.caller2, self.m.method)) diff --git a/test/utils/test_amaranth_ext.py b/test/utils/test_amaranth_ext.py index 7943ccb..349fc0b 100644 --- a/test/utils/test_amaranth_ext.py +++ b/test/utils/test_amaranth_ext.py @@ -1,5 +1,6 @@ from transactron.testing import * import random +import pytest from transactron.utils.amaranth_ext import MultiPriorityEncoder, RingMultiPriorityEncoder @@ -33,26 +34,25 @@ def get_expected_ring(input_width, output_count, input, first, last): ) class TestPriorityEncoder(TestCaseWithSimulator): def process(self, get_expected): - def f(): + async def f(sim: TestbenchContext): for _ in range(self.test_number): input = random.randrange(2**self.input_width) first = random.randrange(self.input_width) last = random.randrange(self.input_width) - yield self.circ.input.eq(input) + sim.set(self.circ.input, input) try: - yield self.circ.first.eq(first) - yield self.circ.last.eq(last) + sim.set(self.circ.first, first) + sim.set(self.circ.last, last) except AttributeError: pass - yield Settle() expected_output = get_expected(self.input_width, self.output_count, input, first, last) for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): if ex is None: - assert (yield valid) == 0 + assert sim.get(valid) == 0 else: - assert (yield valid) == 1 - assert (yield real) == ex - yield Delay(1e-7) + assert sim.get(valid) == 1 + assert sim.get(real) == ex + await sim.delay(1e-7) return f @@ -66,7 +66,7 @@ def test_random(self, test_class, verif_f, input_width, output_count): self.circ = test_class(self.input_width, self.output_count) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) + sim.add_testbench(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) def test_static_create_simple(self, test_class, verif_f, name): @@ -100,7 +100,7 @@ def elaborate(self, platform): self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) + sim.add_testbench(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) def test_static_create(self, test_class, verif_f, name): @@ -132,4 +132,4 @@ def elaborate(self, platform): self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) + sim.add_testbench(self.process(verif_f)) diff --git a/test/utils/test_onehotswitch.py b/test/utils/test_onehotswitch.py index 9d7dc84..b0620c0 100644 --- a/test/utils/test_onehotswitch.py +++ b/test/utils/test_onehotswitch.py @@ -3,7 +3,7 @@ from transactron.utils import OneHotSwitch -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext from parameterized import parameterized @@ -30,33 +30,30 @@ def elaborate(self, platform): return m -class TestAssign(TestCaseWithSimulator): +class TestOneHotSwitch(TestCaseWithSimulator): @parameterized.expand([(False,), (True,)]) def test_onehotswitch(self, test_zero): circuit = OneHotSwitchCircuit(4, test_zero) - def switch_test_proc(): + async def switch_test_proc(sim: TestbenchContext): for i in range(len(circuit.input)): - yield circuit.input.eq(1 << i) - yield Settle() - assert (yield circuit.output) == i + sim.set(circuit.input, 1 << i) + assert sim.get(circuit.output) == i with self.run_simulation(circuit) as sim: - sim.add_process(switch_test_proc) + sim.add_testbench(switch_test_proc) def test_onehotswitch_zero(self): circuit = OneHotSwitchCircuit(4, True) - def switch_test_proc_zero(): + async def switch_test_proc_zero(sim: TestbenchContext): for i in range(len(circuit.input)): - yield circuit.input.eq(1 << i) - yield Settle() - assert (yield circuit.output) == i - assert not (yield circuit.zero) + sim.set(circuit.input, 1 << i) + assert sim.get(circuit.output) == i + assert not sim.get(circuit.zero) - yield circuit.input.eq(0) - yield Settle() - assert (yield circuit.zero) + sim.set(circuit.input, 0) + assert sim.get(circuit.zero) with self.run_simulation(circuit) as sim: - sim.add_process(switch_test_proc_zero) + sim.add_testbench(switch_test_proc_zero) diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 63c1761..abd28f4 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -78,21 +78,21 @@ def setup_method(self): self.test_number = 40 self.m = PopcountTestCircuit(self.size) - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_popcount = yield self.m.sig_out + def check(self, sim: TestbenchContext, n): + sim.set(self.m.sig_in, n) + out_popcount = sim.get(self.m.sig_out) assert out_popcount == n.bit_count(), f"{n:x}" - def process(self): + async def process(self, sim: TestbenchContext): for i in range(self.test_number): n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) + self.check(sim, n) + sim.delay(1e-6) + self.check(sim, 2**self.size - 1) def test_popcount(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) class CLZTestCircuit(Elaboratable): @@ -124,21 +124,21 @@ def setup_method(self): self.test_number = 40 self.m = CLZTestCircuit(self.size) - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_clz = yield self.m.sig_out + def check(self, sim: TestbenchContext, n): + sim.set(self.m.sig_in, n) + out_clz = sim.get(self.m.sig_out) assert out_clz == (2**self.size) - n.bit_length(), f"{n:x}" - def process(self): + async def process(self, sim: TestbenchContext): for i in range(self.test_number): n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) + self.check(sim, n) + sim.delay(1e-6) + self.check(sim, 2**self.size - 1) def test_count_leading_zeros(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) class CTZTestCircuit(Elaboratable): @@ -170,10 +170,9 @@ def setup_method(self): self.test_number = 40 self.m = CTZTestCircuit(self.size) - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_ctz = yield self.m.sig_out + def check(self, sim: TestbenchContext, n): + sim.set(self.m.sig_in, n) + out_ctz = sim.get(self.m.sig_out) expected = 0 if n == 0: @@ -185,12 +184,13 @@ def check(self, n): assert out_ctz == expected, f"{n:x}" - def process(self): + async def process(self, sim: TestbenchContext): for i in range(self.test_number): n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) + self.check(sim, n) + await sim.delay(1e-6) + self.check(sim, 2**self.size - 1) def test_count_trailing_zeros(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index cd09816..91b94f9 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -166,6 +166,7 @@ def __init__( use_condition : bool Instead of `m.If` use simultaneus `condition` which allow to execute this filter if the condition is False and target is not ready. + When `use_condition` is true, `condition` must not be a `Method`. src_loc: int | SrcLoc How many stack frames deep the source location is taken from. Alternatively, the source location to use instead of the default. @@ -180,6 +181,8 @@ def __init__( self.condition = condition self.default = default + assert not (use_condition and isinstance(condition, Method)) + def elaborate(self, platform): m = TModule() diff --git a/transactron/testing/__init__.py b/transactron/testing/__init__.py index aa21522..8c49400 100644 --- a/transactron/testing/__init__.py +++ b/transactron/testing/__init__.py @@ -1,8 +1,10 @@ +from amaranth.sim._async import TestbenchContext, ProcessContext, SimulatorContext # noqa: F401 from .input_generation import * # noqa: F401 from .functions import * # noqa: F401 from .infrastructure import * # noqa: F401 -from .sugar import * # noqa: F401 +from .method_mock import * # noqa: F401 from .testbenchio import * # noqa: F401 from .profiler import * # noqa: F401 from .logging import * # noqa: F401 +from .tick_count import * # noqa: F401 from transactron.utils import data_layout # noqa: F401 diff --git a/transactron/testing/functions.py b/transactron/testing/functions.py index 7cb69b1..ee12251 100644 --- a/transactron/testing/functions.py +++ b/transactron/testing/functions.py @@ -1,31 +1,15 @@ -from amaranth import * -from amaranth.lib.data import Layout, StructLayout, View -from amaranth.sim._pycoro import Command -from typing import TypeVar, Any, Generator, TypeAlias, TYPE_CHECKING, Union -from transactron.utils._typing import RecordIntDict +import amaranth.lib.data as data +from typing import TypeAlias -if TYPE_CHECKING: - from amaranth.hdl._ast import Statement - from .infrastructure import CoreblocksCommand +MethodData: TypeAlias = "data.Const[data.StructLayout]" -T = TypeVar("T") -TestGen: TypeAlias = Generator[Union[Command, Value, "Statement", "CoreblocksCommand", None], Any, T] - - -def get_outputs(field: View) -> TestGen[RecordIntDict]: - # return dict of all signal values in a record because amaranth's simulator can't read all - # values of a View in a single yield - it can only read Values (Signals) - result = {} - layout = field.shape() - assert isinstance(layout, StructLayout) - for name, fld in layout: - val = field[name] - if isinstance(fld.shape, Layout): - result[name] = yield from get_outputs(View(fld.shape, val)) - elif isinstance(val, Value): - result[name] = yield val - else: - raise ValueError - return result +def data_const_to_dict(c: "data.Const[data.Layout]"): + ret = {} + for k, _ in c.shape(): + v = c[k] + if isinstance(v, data.Const): + v = data_const_to_dict(v) + ret[k] = v + return ret diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py index 861428f..a67e1cf 100644 --- a/transactron/testing/infrastructure.py +++ b/transactron/testing/infrastructure.py @@ -4,19 +4,20 @@ import os import random import functools -import warnings from contextlib import contextmanager, nullcontext -from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias, Optional -from abc import ABC +from collections.abc import Callable +from typing import TypeVar, Generic, Type, TypeGuard, Any, cast, TypeAlias, Optional from amaranth import * from amaranth.sim import * +from amaranth.sim._async import SimulatorContext from transactron.utils.dependencies import DependencyContext, DependencyManager from .testbenchio import TestbenchIO from .profiler import profiler_process, Profile -from .functions import TestGen from .logging import make_logging_process, parse_logging_level, _LogFormatter +from .tick_count import make_tick_count_process from .gtkw_extension import write_vcd_ext +from .method_mock import MethodMock from transactron import Method from transactron.lib import AdapterTrans from transactron.core.keys import TransactionManagerKey @@ -24,6 +25,9 @@ from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals +__all__ = ["SimpleTestCircuit", "PysimSimulator", "TestCaseWithSimulator"] + + T = TypeVar("T") _T_nested_collection: TypeAlias = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] @@ -56,7 +60,10 @@ def __getattr__(self, name: str) -> Any: def elaborate(self, platform): def transform_methods_to_testbenchios( container: _T_nested_collection[Method], - ) -> tuple[_T_nested_collection["TestbenchIO"], Union[ModuleConnector, "TestbenchIO"]]: + ) -> tuple[ + _T_nested_collection["TestbenchIO"], + "ModuleConnector | TestbenchIO", + ]: if isinstance(container, list): tb_list = [] mc_list = [] @@ -119,43 +126,6 @@ def elaborate(self, platform) -> HasElaborate: return m -class CoreblocksCommand(ABC): - pass - - -class Now(CoreblocksCommand): - pass - - -class SyncProcessWrapper: - def __init__(self, f): - self.org_process = f - self.current_cycle = 0 - - def _wrapping_function(self): - response = None - org_coroutine = self.org_process() - try: - while True: - # call orginal test process and catch data yielded by it in `command` variable - command = org_coroutine.send(response) - # If process wait for new cycle - if command is None or isinstance(command, Tick): - command = command or Tick() - # TODO: use of other domains can mess up the counter! - if command.domain == "sync": - self.current_cycle += 1 - # forward to amaranth - yield command - elif isinstance(command, Now): - response = self.current_cycle - # Pass everything else to amaranth simulator without modifications - else: - response = yield command - except StopIteration: - pass - - class PysimSimulator(Simulator): def __init__( self, @@ -197,10 +167,6 @@ def __init__( self.deadline = clk_period * max_cycles - def add_process(self, f: Callable[[], TestGen]): - f_wrapped = SyncProcessWrapper(f) - super().add_process(f_wrapped._wrapping_function) - def run(self) -> bool: with self.ctx: self.run_until(self.deadline) @@ -212,34 +178,44 @@ class TestCaseWithSimulator: dependency_manager: DependencyManager @contextmanager - def configure_dependency_context(self): + def _configure_dependency_context(self): self.dependency_manager = DependencyManager() with DependencyContext(self.dependency_manager): yield Tick() - def add_class_mocks(self, sim: PysimSimulator) -> None: + def add_mock(self, sim: PysimSimulator, val: MethodMock): + sim.add_process(val.output_process) + if val.validate_arguments is not None: + sim.add_process(val.validate_arguments_process) + sim.add_testbench(val.effect_process, background=True) + + def _add_class_mocks(self, sim: PysimSimulator) -> None: for key in dir(self): val = getattr(self, key) if hasattr(val, "_transactron_testing_process"): sim.add_process(val) + elif hasattr(val, "_transactron_method_mock"): + self.add_mock(sim, val()) - def add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: + def _add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: for key, val in frame_locals.items(): if hasattr(val, "_transactron_testing_process"): sim.add_process(val) + elif hasattr(val, "_transactron_method_mock"): + self.add_mock(sim, val()) - def add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: - self.add_class_mocks(sim) - self.add_local_mocks(sim, frame_locals) + def _add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: + self._add_class_mocks(sim) + self._add_local_mocks(sim, frame_locals) - def configure_traces(self): + def _configure_traces(self): traces_file = None if "__TRANSACTRON_DUMP_TRACES" in os.environ: traces_file = self._transactron_current_output_file_name self._transactron_infrastructure_traces_file = traces_file @contextmanager - def configure_profiles(self): + def _configure_profiles(self): profile = None if "__TRANSACTRON_PROFILE" in os.environ: @@ -264,7 +240,7 @@ def f(): profile.encode(f"{profile_dir}/{profile_file}.json") @contextmanager - def configure_logging(self): + def _configure_logging(self): def on_error(): assert False, "Simulation finished due to an error" @@ -292,10 +268,11 @@ def reinitialize_fixtures(self): self._transactron_base_output_file_name + "_" + str(self._transactron_hypothesis_iter_counter) ) self._transactron_sim_processes_to_add: list[Callable[[], Optional[Callable]]] = [] - with self.configure_dependency_context(): - self.configure_traces() - with self.configure_profiles(): - with self.configure_logging(): + with self._configure_dependency_context(): + self._configure_traces() + with self._configure_profiles(): + with self._configure_logging(): + self._transactron_sim_processes_to_add.append(make_tick_count_process) yield self._transactron_hypothesis_iter_counter += 1 @@ -323,7 +300,7 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra traces_file=self._transactron_infrastructure_traces_file, clk_period=clk_period, ) - self.add_all_mocks(sim, sys._getframe(2).f_locals) + self._add_all_mocks(sim, sys._getframe(2).f_locals) yield sim @@ -332,34 +309,25 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra if ret is not None: sim.add_process(ret) - with warnings.catch_warnings(): - # TODO: figure out testing without settles! - warnings.filterwarnings("ignore", r"The `Settle` command is deprecated per RFC 27\.") - - res = sim.run() + res = sim.run() assert res, "Simulation time limit exceeded" - def tick(self, cycle_cnt: int = 1): + async def tick(self, sim: SimulatorContext, cycle_cnt: int = 1): """ - Yields for the given number of cycles. + Waits for the given number of cycles. """ - for _ in range(cycle_cnt): - yield Tick() + await sim.tick() - def random_wait(self, max_cycle_cnt: int, *, min_cycle_cnt: int = 0): + async def random_wait(self, sim: SimulatorContext, max_cycle_cnt: int, *, min_cycle_cnt: int = 0): """ Wait for a random amount of cycles in range [min_cycle_cnt, max_cycle_cnt] """ - yield from self.tick(random.randrange(min_cycle_cnt, max_cycle_cnt + 1)) + await self.tick(sim, random.randrange(min_cycle_cnt, max_cycle_cnt + 1)) - def random_wait_geom(self, prob: float = 0.5): + async def random_wait_geom(self, sim: SimulatorContext, prob: float = 0.5): """ Wait till the first success, where there is `prob` probability for success in each cycle. """ while random.random() > prob: - yield Tick() - - def multi_settle(self, settle_count: int = 1): - for _ in range(settle_count): - yield Settle() + await sim.tick() diff --git a/transactron/testing/logging.py b/transactron/testing/logging.py index 7c8edf1..2a40d6f 100644 --- a/transactron/testing/logging.py +++ b/transactron/testing/logging.py @@ -1,9 +1,12 @@ -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any import logging +import itertools -from amaranth.sim import Passive, Tick +from amaranth.sim._async import ProcessContext from transactron.lib import logging as tlog +from transactron.utils.dependencies import DependencyContext +from .tick_count import TicksKey __all__ = ["make_logging_process", "parse_logging_level"] @@ -31,7 +34,7 @@ def parse_logging_level(str: str) -> tlog.LogLevel: raise ValueError("Log level must be either {error, warn, info, debug} or a non-negative integer.") -_sim_cycle = 0 +_sim_cycle: int = 0 class _LogFormatter(logging.Formatter): @@ -65,17 +68,15 @@ def make_logging_process(level: tlog.LogLevel, namespace_regexp: str, on_error: root_logger = logging.getLogger() - def handle_logs(): - if not (yield combined_trigger): - return + def handle_logs(record_vals: Iterable[int]) -> None: + it = iter(record_vals) for record in records: - if not (yield record.trigger): - continue + trigger = next(it) + values = [next(it) for _ in record.fields] - values: list[int] = [] - for field in record.fields: - values.append((yield field)) + if not trigger: + continue formatted_msg = record.format(*values) @@ -91,15 +92,18 @@ def handle_logs(): if record.level >= logging.ERROR: on_error() - def log_process(): + async def log_process(sim: ProcessContext) -> None: global _sim_cycle - _sim_cycle = 0 - - yield Passive() - while True: - yield Tick("sync_neg") - yield from handle_logs() - yield Tick() - _sim_cycle += 1 + ticks = DependencyContext.get().get_dependency(TicksKey()) + + async for _, _, ticks_val, combined_trigger_val, *record_vals in ( + sim.tick("sync_neg") + .sample(ticks, combined_trigger) + .sample(*itertools.chain(*([record.trigger] + record.fields for record in records))) + ): + if not combined_trigger_val: + continue + _sim_cycle = ticks_val + handle_logs(record_vals) return log_process diff --git a/transactron/testing/method_mock.py b/transactron/testing/method_mock.py new file mode 100644 index 0000000..9587ae1 --- /dev/null +++ b/transactron/testing/method_mock.py @@ -0,0 +1,175 @@ +from contextlib import contextmanager +import functools +from typing import Callable, Any, Optional + +from amaranth.sim._async import SimulatorContext +from transactron.lib.adapters import Adapter +from transactron.utils.transactron_helpers import async_mock_def_helper +from .testbenchio import TestbenchIO +from transactron.utils._typing import RecordIntDict + + +__all__ = ["MethodMock", "def_method_mock"] + + +class MethodMock: + def __init__( + self, + adapter: Adapter, + function: Callable[..., Optional[RecordIntDict]], + *, + validate_arguments: Optional[Callable[..., bool]] = None, + enable: Callable[[], bool] = lambda: True, + delay: float = 0, + ): + self.adapter = adapter + self.function = function + self.validate_arguments = validate_arguments + self.enable = enable + self.delay = delay + self._effects: list[Callable[[], None]] = [] + self._freeze = False + + _current_mock: Optional["MethodMock"] = None + + @staticmethod + def effect(effect: Callable[[], None]): + assert MethodMock._current_mock is not None + MethodMock._current_mock._effects.append(effect) + + @contextmanager + def _context(self): + assert MethodMock._current_mock is None + MethodMock._current_mock = self + try: + yield + finally: + MethodMock._current_mock = None + + async def output_process( + self, + sim: SimulatorContext, + ) -> None: + sync = sim._design.lookup_domain("sync", None) # type: ignore + async for *_, done, arg, clk in sim.changed(self.adapter.done, self.adapter.data_out).edge(sync.clk, 1): + if clk: + self._freeze = True + if not done or self._freeze: + continue + self._effects = [] + with self._context(): + ret = async_mock_def_helper(self, self.function, arg) + sim.set(self.adapter.data_in, ret) + + async def validate_arguments_process(self, sim: SimulatorContext) -> None: + assert self.validate_arguments is not None + sync = sim._design.lookup_domain("sync", None) # type: ignore + async for *args, clk, _ in ( + sim.changed(*(a for a, _ in self.adapter.validators)).edge(sync.clk, 1).edge(self.adapter.en, 1) + ): + assert len(args) == len(self.adapter.validators) # TODO: remove later + if clk: + self._freeze = True + if self._freeze: + continue + for arg, r in zip(args, (r for _, r in self.adapter.validators)): + sim.set(r, async_mock_def_helper(self, self.validate_arguments, arg)) + + async def effect_process(self, sim: SimulatorContext) -> None: + sim.set(self.adapter.en, self.enable()) + async for *_, done in sim.tick().sample(self.adapter.done): + # Disabling the method on each cycle forces an edge when it is reenabled again. + # The method body won't be executed until the effects are done. + sim.set(self.adapter.en, False) + + # First, perform pending effects, updating internal state. + with sim.critical(): + if done: + for eff in self._effects: + eff() + + # Ensure that the effects of all mocks are applied. Delay 0 also does this! + await sim.delay(self.delay) + + # Next, enable the method. The output will be updated by a combinational process. + self._effects = [] + self._freeze = False + sim.set(self.adapter.en, self.enable()) + + +def def_method_mock( + tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], **kwargs +) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], MethodMock]]: + """ + Decorator function to create method mock handlers. It should be applied on + a function which describes functionality which we want to invoke on method call. + This function will be called on every clock cycle when the method is active, + and also on combinational changes to inputs. + + The decorated function can have a single argument `arg`, which receives + the arguments passed to a method as a `data.Const`, or multiple named arguments, + which correspond to named arguments of the method. + + This decorator can be applied to function definitions or method definitions. + When applied to a method definition, lambdas passed to `def_method_mock` + need to take a `self` argument, which should be the first. + + Mocks defined at class level or at test level are automatically discovered and + don't need to be manually added to the simulation. + + Any side effects (state modification, assertions, etc.) need to be guarded + using the `MethodMock.effect` decorator. + + Make sure to defer accessing state, since decorators are evaluated eagerly + during function declaration. + + Parameters + ---------- + tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] + Function to get the TestbenchIO of the mocked method. + enable : Callable[[], bool] | Callable[[Any], bool] + Function which decides if the method is enabled in a given clock cycle. + validate_arguments : Callable[..., bool] + Function which validates call arguments. This applies only to Adapters + with `with_validate_arguments` set to True. + delay : float + Simulation time delay for method mock calling. Used for synchronization + between different mocks and testbench processes. + + Example + ------- + ``` + @def_method_mock(lambda: m.target[k]) + def process(arg): + return {"data": arg["data"] + k} + ``` + or for class methods + ``` + @def_method_mock(lambda self: self.target[k]) + def process(self, data): + return {"data": data + k} + ``` + """ + + def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], MethodMock]: + @functools.wraps(func) + def mock(func_self=None, /) -> MethodMock: + f = func + getter: Any = tb_getter + kw = kwargs + if func_self is not None: + getter = getter.__get__(func_self) + f = f.__get__(func_self) + kw = {} + for k, v in kwargs.items(): + bind = getattr(v, "__get__", None) + kw[k] = bind(func_self) if bind else v + tb = getter() + assert isinstance(tb, TestbenchIO) + assert isinstance(tb.adapter, Adapter) + return MethodMock(tb.adapter, f, **kw) + + mock._transactron_method_mock = 1 # type: ignore + return mock + + return decorator diff --git a/transactron/testing/profiler.py b/transactron/testing/profiler.py index 795c7f2..ace2b63 100644 --- a/transactron/testing/profiler.py +++ b/transactron/testing/profiler.py @@ -1,37 +1,46 @@ -from amaranth.sim import * +from amaranth import Cat +from amaranth.lib.data import StructLayout, View +from amaranth.sim._async import ProcessContext from transactron.core import TransactionManager from transactron.core.manager import MethodMap from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileData, ProfileSamples, TransactionSamples -from .functions import TestGen __all__ = ["profiler_process"] def profiler_process(transaction_manager: TransactionManager, profile: Profile): - def process() -> TestGen: + async def process(sim: ProcessContext) -> None: profile_data, get_id = ProfileData.make(transaction_manager) method_map = MethodMap(transaction_manager.transactions) profile.transactions_and_methods = profile_data.transactions_and_methods - yield Passive() - while True: - yield Tick("sync_neg") + transaction_sample_layout = StructLayout({"request": 1, "runnable": 1, "grant": 1}) + async for _, _, *data in ( + sim.tick() + .sample( + *( + View(transaction_sample_layout, Cat(transaction.request, transaction.runnable, transaction.grant)) + for transaction in method_map.transactions + ) + ) + .sample(*(method.run for method in method_map.methods)) + ): + transaction_data = data[: len(method_map.transactions)] + method_data = data[len(method_map.transactions) :] samples = ProfileSamples() - for transaction in method_map.transactions: + for transaction, tsample in zip(method_map.transactions, transaction_data): samples.transactions[get_id(transaction)] = TransactionSamples( - bool((yield transaction.request)), - bool((yield transaction.runnable)), - bool((yield transaction.grant)), + bool(tsample.request), + bool(tsample.runnable), + bool(tsample.grant), ) - for method in method_map.methods: - samples.methods[get_id(method)] = MethodSamples(bool((yield method.run))) + for method, run in zip(method_map.methods, method_data): + samples.methods[get_id(method)] = MethodSamples(bool(run)) cprof = CycleProfile.make(samples, profile_data) profile.cycles.append(cprof) - yield Tick() - return process diff --git a/transactron/testing/sugar.py b/transactron/testing/sugar.py deleted file mode 100644 index de1dc5e..0000000 --- a/transactron/testing/sugar.py +++ /dev/null @@ -1,82 +0,0 @@ -import functools -from typing import Callable, Any, Optional -from .testbenchio import TestbenchIO, TestGen -from transactron.utils._typing import RecordIntDict - - -def def_method_mock( - tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs -) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]: - """ - Decorator function to create method mock handlers. It should be applied on - a function which describes functionality which we want to invoke on method call. - Such function will be wrapped by `method_handle_loop` and called on each - method invocation. - - Function `f` should take only one argument `arg` - data used in function - invocation - and should return data to be sent as response to the method call. - - Function `f` can also be a method and take two arguments `self` and `arg`, - the data to be passed on to invoke a method. It should return data to be sent - as response to the method call. - - Instead of the `arg` argument, the data can be split into keyword arguments. - - Make sure to defer accessing state, since decorators are evaluated eagerly - during function declaration. - - Parameters - ---------- - tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] - Function to get the TestbenchIO providing appropriate `method_handle_loop`. - **kwargs - Arguments passed to `method_handle_loop`. - - Example - ------- - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k]) - def process(arg): - return {"data": arg["data"] + k} - return process - ``` - or equivalently - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k], settle=1, enable=False) - def process(data): - return {"data": data + k} - return process - ``` - or for class methods - ``` - @def_method_mock(lambda self: self.target[k], settle=1, enable=False) - def process(self, data): - return {"data": data + k} - ``` - """ - - def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], TestGen[None]]: - @functools.wraps(func) - def mock(func_self=None, /) -> TestGen[None]: - f = func - getter: Any = tb_getter - kw = kwargs - if func_self is not None: - getter = getter.__get__(func_self) - f = f.__get__(func_self) - kw = {} - for k, v in kwargs.items(): - bind = getattr(v, "__get__", None) - kw[k] = bind(func_self) if bind else v - tb = getter() - assert isinstance(tb, TestbenchIO) - yield from tb.method_handle_loop(f, extra_settle_count=sched_prio, **kw) - - mock._transactron_testing_process = 1 # type: ignore - return mock - - return decorator diff --git a/transactron/testing/testbenchio.py b/transactron/testing/testbenchio.py index 9af6340..0553184 100644 --- a/transactron/testing/testbenchio.py +++ b/transactron/testing/testbenchio.py @@ -1,11 +1,130 @@ +from collections.abc import Generator, Iterable from amaranth import * -from amaranth.sim import Settle, Passive, Tick -from typing import Optional, Callable +from amaranth.lib.data import View, StructLayout +from amaranth.sim._async import SimulatorContext, TestbenchContext +from typing import Any, Optional + from transactron.lib import AdapterBase -from transactron.lib.adapters import Adapter -from transactron.utils import ValueLike, SignalBundle, mock_def_helper, assign -from transactron.utils._typing import RecordIntDictRet, RecordValueDict, RecordIntDict -from .functions import get_outputs, TestGen +from transactron.utils import ValueLike +from .functions import MethodData + + +__all__ = ["CallTrigger", "TestbenchIO"] + + +class CallTrigger: + """A trigger which allows to call multiple methods and sample signals. + + The `call()` and `call_try()` methods on a `TestbenchIO` always wait at least one clock cycle. It follows + that these methods can't be used to perform calls to multiple methods in a single clock cycle. Usually + this is not a problem, as different methods can be called from different simulation processes. But in cases + when more control over the time when different calls happen is needed, this trigger class allows to call + many methods in a single clock cycle. + """ + + def __init__( + self, + sim: SimulatorContext, + _calls: Iterable[ValueLike | tuple["TestbenchIO", Optional[dict[str, Any]]]] = (), + ): + """ + Parameters + ---------- + sim: SimulatorContext + Amaranth simulator context. + """ + self.sim = sim + self.calls_and_values: list[ValueLike | tuple[TestbenchIO, Optional[dict[str, Any]]]] = list(_calls) + + def sample(self, *values: "ValueLike | TestbenchIO"): + """Sample a signal or a method result on a clock edge. + + Values are sampled like in standard Amaranth `TickTrigger`. Sampling a method result works like `call()`, + but the method is not called - another process can do that instead. If the method was not called, the + sampled value is `None`. + + Parameters + ---------- + *values: ValueLike | TestbenchIO + Value or method to sample. + """ + new_calls_and_values: list[ValueLike | tuple["TestbenchIO", None]] = [] + for value in values: + if isinstance(value, TestbenchIO): + new_calls_and_values.append((value, None)) + else: + new_calls_and_values.append(value) + return CallTrigger(self.sim, (*self.calls_and_values, *new_calls_and_values)) + + def call(self, tbio: "TestbenchIO", data: dict[str, Any] = {}, /, **kwdata): + """Call a method and sample its result. + + Adds a method call to the trigger. The method result is sampled on a clock edge. If the call did not + succeed, the sampled value is `None`. + + Parameters + ---------- + tbio: TestbenchIO + The method to call. + data: dict[str, Any] + Method call arguments stored in a dict. + **kwdata: Any + Method call arguments passed as keyword arguments. If keyword arguments are used, + the `data` argument should not be provided. + """ + if data and kwdata: + raise TypeError("call() takes either a single dict or keyword arguments") + return CallTrigger(self.sim, (*self.calls_and_values, (tbio, data or kwdata))) + + async def until_done(self) -> Any: + """Wait until at least one of the calls succeeds. + + The `CallTrigger` normally acts like `TickTrigger`, e.g. awaiting on it advances the clock to the next + clock edge. It is possible that none of the calls could not be performed, for example because the called + methods were not enabled. In case we only want to focus on the cycles when one of the calls succeeded, + `until_done` can be used. This works like `until()` in `TickTrigger`. + """ + async for results in self: + if any(res is not None for res in results): + return results + + def __await__(self) -> Generator: + only_calls = [t for t in self.calls_and_values if isinstance(t, tuple)] + only_values = [t for t in self.calls_and_values if not isinstance(t, tuple)] + + for tbio, data in only_calls: + if data is not None: + tbio.call_init(self.sim, data) + + def layout_for(tbio: TestbenchIO): + return StructLayout({"outputs": tbio.adapter.data_out.shape(), "done": 1}) + + trigger = ( + self.sim.tick() + .sample(*(View(layout_for(tbio), Cat(tbio.outputs, tbio.done)) for tbio, _ in only_calls)) + .sample(*only_values) + ) + _, _, *results = yield from trigger.__await__() + + for tbio, data in only_calls: + if data is not None: + tbio.disable(self.sim) + + values_it = iter(results[len(only_calls) :]) + calls_it = (s.outputs if s.done else None for s in results[: len(only_calls)]) + + def ret(): + for v in self.calls_and_values: + if isinstance(v, tuple): + yield next(calls_it) + else: + yield next(values_it) + + return tuple(ret()) + + async def __aiter__(self): + while True: + yield await self class TestbenchIO(Elaboratable): @@ -19,134 +138,69 @@ def elaborate(self, platform): # Low-level operations - def set_enable(self, en) -> TestGen[None]: - yield self.adapter.en.eq(1 if en else 0) + def set_enable(self, sim: SimulatorContext, en): + sim.set(self.adapter.en, 1 if en else 0) - def enable(self) -> TestGen[None]: - yield from self.set_enable(True) + def enable(self, sim: SimulatorContext): + self.set_enable(sim, True) - def disable(self) -> TestGen[None]: - yield from self.set_enable(False) + def disable(self, sim: SimulatorContext): + self.set_enable(sim, False) - def done(self) -> TestGen[int]: - return (yield self.adapter.done) + @property + def done(self): + return self.adapter.done - def wait_until_done(self) -> TestGen[None]: - while (yield self.adapter.done) != 1: - yield Tick() + @property + def outputs(self): + return self.adapter.data_out - def set_inputs(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from assign(self.adapter.data_in, data) + def set_inputs(self, sim: SimulatorContext, data): + sim.set(self.adapter.data_in, data) - def get_outputs(self) -> TestGen[RecordIntDictRet]: - return (yield from get_outputs(self.adapter.data_out)) + def get_done(self, sim: TestbenchContext): + return sim.get(self.adapter.done) - # Operations for AdapterTrans + def get_outputs(self, sim: TestbenchContext) -> MethodData: + return sim.get(self.adapter.data_out) - def call_init(self, data: RecordValueDict = {}, /, **kwdata: ValueLike | RecordValueDict) -> TestGen[None]: - if data and kwdata: - raise TypeError("call_init() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.enable() - yield from self.set_inputs(data) + def sample_outputs(self, sim: SimulatorContext): + return sim.tick().sample(self.adapter.data_out) - def call_result(self) -> TestGen[Optional[RecordIntDictRet]]: - if (yield from self.done()): - return (yield from self.get_outputs()) - return None + def sample_outputs_until_done(self, sim: SimulatorContext): + return self.sample_outputs(sim).until(self.adapter.done) - def call_do(self) -> TestGen[RecordIntDict]: - while (outputs := (yield from self.call_result())) is None: - yield Tick() - yield from self.disable() - return outputs + def sample_outputs_done(self, sim: SimulatorContext): + return sim.tick().sample(self.adapter.data_out, self.adapter.done) - def call_try( - self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict - ) -> TestGen[Optional[RecordIntDictRet]]: - if data and kwdata: - raise TypeError("call_try() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.call_init(data) - yield Tick() - outputs = yield from self.call_result() - yield from self.disable() - return outputs + # Operations for AdapterTrans - def call(self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict) -> TestGen[RecordIntDictRet]: + def call_init(self, sim: SimulatorContext, data={}, /, **kwdata): if data and kwdata: - raise TypeError("call() takes either a single dict or keyword arguments") + raise TypeError("call_init() takes either a single dict or keyword arguments") if not data: data = kwdata - yield from self.call_init(data) - yield Tick() - return (yield from self.call_do()) - - # Operations for Adapter + self.enable(sim) + self.set_inputs(sim, data) - def method_argument(self) -> TestGen[Optional[RecordIntDictRet]]: - return (yield from self.call_result()) + def get_call_result(self, sim: TestbenchContext) -> Optional[MethodData]: + if self.get_done(sim): + return self.get_outputs(sim) + return None - def method_return(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from self.set_inputs(data) + async def call_result(self, sim: SimulatorContext) -> Optional[MethodData]: + *_, data, done = await self.sample_outputs_done(sim) + if done: + return data + return None - def method_handle( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - validate_arguments: Optional[Callable[..., bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - enable = enable or (lambda: True) - yield from self.set_enable(enable()) - - def handle_validate_arguments(): - if validate_arguments is not None: - assert isinstance(self.adapter, Adapter) - for a, r in self.adapter.validators: - ret_out = mock_def_helper(self, validate_arguments, (yield from get_outputs(a))) - yield r.eq(ret_out) - for _ in range(extra_settle_count + 1): - yield Settle() - - # One extra Settle() required to propagate enable signal. - for _ in range(extra_settle_count + 1): - yield Settle() - yield from handle_validate_arguments() - while (arg := (yield from self.method_argument())) is None: - yield Tick() - - yield from self.set_enable(enable()) - for _ in range(extra_settle_count + 1): - yield Settle() - yield from handle_validate_arguments() - - ret_out = mock_def_helper(self, function, arg) - yield from self.method_return(ret_out or {}) - yield Tick() - yield from self.set_enable(False) - - def method_handle_loop( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - validate_arguments: Optional[Callable[..., bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - yield Passive() - while True: - yield from self.method_handle( - function, - enable=enable, - validate_arguments=validate_arguments, - extra_settle_count=extra_settle_count, - ) + async def call_do(self, sim: SimulatorContext) -> MethodData: + *_, outputs = await self.sample_outputs_until_done(sim) + self.disable(sim) + return outputs - # Debug signals + async def call_try(self, sim: SimulatorContext, data={}, /, **kwdata) -> Optional[MethodData]: + return (await CallTrigger(sim).call(self, data, **kwdata))[0] - def debug_signals(self) -> SignalBundle: - return self.adapter.debug_signals() + async def call(self, sim: SimulatorContext, data={}, /, **kwdata) -> MethodData: + return (await CallTrigger(sim).call(self, data, **kwdata).until_done())[0] diff --git a/transactron/testing/tick_count.py b/transactron/testing/tick_count.py new file mode 100644 index 0000000..a2d3828 --- /dev/null +++ b/transactron/testing/tick_count.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass + +from amaranth import Signal +from amaranth.sim._async import ProcessContext + +from transactron.utils.dependencies import DependencyContext, SimpleKey + + +__all__ = ["TicksKey", "make_tick_count_process"] + + +@dataclass(frozen=True) +class TicksKey(SimpleKey[Signal]): + pass + + +def make_tick_count_process(): + ticks = Signal(64) + DependencyContext.get().add_dependency(TicksKey(), ticks) + + async def process(sim: ProcessContext): + async for _, _, ticks_val in sim.tick().sample(ticks): + sim.set(ticks, ticks_val + 1) + + return process diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py index 1fae882..9cb23cd 100644 --- a/transactron/utils/transactron_helpers.py +++ b/transactron/utils/transactron_helpers.py @@ -8,6 +8,7 @@ from amaranth import * from amaranth import tracer from amaranth.lib.data import StructLayout +import amaranth.lib.data as data __all__ = [ @@ -17,6 +18,7 @@ "def_helper", "method_def_helper", "mock_def_helper", + "async_mock_def_helper", "get_src_loc", "from_method_layout", "make_layout", @@ -103,6 +105,13 @@ def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) +def async_mock_def_helper(tb, func: Callable[..., T], arg: "data.Const[StructLayout]") -> T: + marg = {} + for k, _ in arg.shape(): + marg[k] = arg[k] + return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], marg, **marg) + + def method_def_helper(method, func: Callable[..., T], arg: MethodStruct) -> T: kwargs = {k: arg[k] for k in arg.shape().members} return def_helper(f"method definition for {method}", func, MethodStruct, arg, **kwargs)