diff --git a/coreblocks/frontend/fetch.py b/coreblocks/frontend/fetch.py index 38062526f..1a179b430 100644 --- a/coreblocks/frontend/fetch.py +++ b/coreblocks/frontend/fetch.py @@ -151,7 +151,7 @@ def elaborate(self, platform) -> TModule: m.d.av_comb += stalled.eq(stalled_unsafe | stalled_exception) with Transaction().body(m, request=~stalled): - aligned_pc = Cat(Repl(0, 2), cache_req_pc[2:]) + aligned_pc = Cat(C(0, 2), cache_req_pc[2:]) self.icache.issue_req(m, addr=aligned_pc) req_limiter.acquire(m) diff --git a/coreblocks/frontend/icache.py b/coreblocks/frontend/icache.py index 576ea1220..16d4462db 100644 --- a/coreblocks/frontend/icache.py +++ b/coreblocks/frontend/icache.py @@ -88,7 +88,7 @@ def _(addr: Value) -> None: addr=addr >> log2_int(self.params.word_width_bytes), data=0, we=0, - sel=Repl(1, self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity), + sel=C(1).replicate(self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity), ) @def_method(m, self.accept_res) @@ -275,7 +275,7 @@ def _() -> None: with m.If(fsm.ongoing("FLUSH")): m.d.comb += [ - self.mem.way_wr_en.eq(Repl(1, self.params.num_of_ways)), + self.mem.way_wr_en.eq(C(1).replicate(self.params.num_of_ways)), self.mem.tag_wr_index.eq(flush_index), self.mem.tag_wr_data.valid.eq(0), self.mem.tag_wr_data.tag.eq(0), @@ -393,7 +393,7 @@ def elaborate(self, platform): addr=Cat(address["word_counter"], address["refill_address"]), data=0, we=0, - sel=Repl(1, self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity), + sel=C(1).replicate(self.wb_master.wb_params.data_width // self.wb_master.wb_params.granularity), ) @def_method(m, self.start_refill, ready=~refill_active) @@ -421,7 +421,7 @@ def _(): address_fwd.write(m, word_counter=next_word_counter, refill_address=refill_address) return { - "addr": Cat(Repl(0, log2_int(self.params.word_width_bytes)), word_counter, refill_address), + "addr": Cat(C(0, log2_int(self.params.word_width_bytes)), word_counter, refill_address), "data": fetched.data, "error": fetched.err, "last": last, diff --git a/coreblocks/frontend/rvc.py b/coreblocks/frontend/rvc.py index 7da91dc9b..bd01255fd 100644 --- a/coreblocks/frontend/rvc.py +++ b/coreblocks/frontend/rvc.py @@ -47,10 +47,10 @@ def _quadrant_0(self) -> list[DecodedInstr]: rd = self.decompr_reg(self.instr_in[2:5]) addi4spn_imm = Cat( - Repl(0, 2), self.instr_in[6], self.instr_in[5], self.instr_in[11:13], self.instr_in[7:11], Repl(0, 2) + C(0, 2), self.instr_in[6], self.instr_in[5], self.instr_in[11:13], self.instr_in[7:11], C(0, 2) ) - lsd_imm = Cat(Repl(0, 3), self.instr_in[10:13], self.instr_in[5:7], Repl(0, 4)) - lsw_imm = Cat(Repl(0, 2), self.instr_in[6], self.instr_in[10:13], self.instr_in[5], Repl(0, 5)) + lsd_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[5:7], C(0, 4)) + lsw_imm = Cat(C(0, 2), self.instr_in[6], self.instr_in[10:13], self.instr_in[5], C(0, 5)) addi4spn = ( ITypeInstr(opcode=Opcode.OP_IMM, rd=rd, funct3=Funct3.ADD, rs1=Registers.SP, imm=addi4spn_imm), @@ -98,18 +98,18 @@ def _quadrant_1(self) -> list[DecodedInstr]: rs2 = self.decompr_reg(self.instr_in[2:5]) rd = self.instr_in[7:12] - addi_imm = Cat(self.instr_in[2:7], Repl(self.instr_in[12], 7)) + addi_imm = Cat(self.instr_in[2:7], self.instr_in[12].replicate(7)) addi16sp_imm = Cat( - Repl(0, 4), + C(0, 4), self.instr_in[6], self.instr_in[2], self.instr_in[5], self.instr_in[3:5], - Repl(self.instr_in[12], 3), + self.instr_in[12].replicate(3), ) - lui_imm = Cat(Repl(0, 12), self.instr_in[2:7], Repl(self.instr_in[12], 15)) + lui_imm = Cat(C(0, 12), self.instr_in[2:7], self.instr_in[12].replicate(15)) j_imm = Cat( - Repl(0, 1), + C(0, 1), self.instr_in[3:6], self.instr_in[11], self.instr_in[2], @@ -117,15 +117,15 @@ def _quadrant_1(self) -> list[DecodedInstr]: self.instr_in[6], self.instr_in[9:11], self.instr_in[8], - Repl(self.instr_in[12], 10), + self.instr_in[12].replicate(10), ) b_imm = Cat( - Repl(0, 1), + C(0, 1), self.instr_in[3:5], self.instr_in[10:12], self.instr_in[2], self.instr_in[5:7], - Repl(self.instr_in[12], 5), + self.instr_in[12].replicate(5), ) shamt = Cat(self.instr_in[2:7], self.instr_in[12]) @@ -206,10 +206,10 @@ def _quadrant_2(self) -> list[DecodedInstr]: rs2 = self.instr_in[2:7] shamt = Cat(self.instr_in[2:7], self.instr_in[12]) - ldsp_imm = Cat(Repl(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], Repl(0, 3)) - lwsp_imm = Cat(Repl(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], Repl(0, 4)) - sdsp_imm = Cat(Repl(0, 3), self.instr_in[10:13], self.instr_in[7:10], Repl(0, 2)) - swsp_imm = Cat(Repl(0, 2), self.instr_in[9:13], self.instr_in[7:9], Repl(0, 4)) + ldsp_imm = Cat(C(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], C(0, 3)) + lwsp_imm = Cat(C(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], C(0, 4)) + sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 2)) + swsp_imm = Cat(C(0, 2), self.instr_in[9:13], self.instr_in[7:9], C(0, 4)) slli = ( RTypeInstr( @@ -249,10 +249,10 @@ def _quadrant_2(self) -> list[DecodedInstr]: sdsp = STypeInstr(opcode=Opcode.STORE, imm=sdsp_imm, funct3=Funct3.D, rs1=Registers.SP, rs2=rs2) jr = ( - ITypeInstr(opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=rd_rs1, imm=Repl(0, 12)), + ITypeInstr(opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=rd_rs1, imm=C(0, 12)), rd_rs1.any(), ) - jalr = ITypeInstr(opcode=Opcode.JALR, rd=Registers.RA, funct3=Funct3.JALR, rs1=rd_rs1, imm=Repl(0, 12)) + jalr = ITypeInstr(opcode=Opcode.JALR, rd=Registers.RA, funct3=Funct3.JALR, rs1=rd_rs1, imm=C(0, 12)) ebreak = EBreakInstr() diff --git a/coreblocks/fu/alu.py b/coreblocks/fu/alu.py index 7714ab35e..bc5bf72b5 100644 --- a/coreblocks/fu/alu.py +++ b/coreblocks/fu/alu.py @@ -186,15 +186,15 @@ def elaborate(self, platform): m.d.comb += clz.in_sig.eq(self.in1[::-1]) m.d.comb += self.out.eq(clz.out_sig) with OneHotCase(AluFn.Fn.SEXTH): - m.d.comb += self.out.eq(Cat(self.in1[0:16], Repl(self.in1[15], xlen - 16))) + m.d.comb += self.out.eq(Cat(self.in1[0:16], self.in1[15].replicate(xlen - 16))) with OneHotCase(AluFn.Fn.SEXTB): - m.d.comb += self.out.eq(Cat(self.in1[0:8], Repl(self.in1[7], xlen - 8))) + m.d.comb += self.out.eq(Cat(self.in1[0:8], self.in1[7].replicate(xlen - 8))) with OneHotCase(AluFn.Fn.ZEXTH): m.d.comb += self.out.eq(Cat(self.in1[0:16], C(0, shape=unsigned(xlen - 16)))) with OneHotCase(AluFn.Fn.ORCB): def _or(s: Value) -> Value: - return Repl(s.any(), 8) + return s.any().replicate(8) for i in range(xlen // 8): m.d.comb += self.out[i * 8 : (i + 1) * 8].eq(_or(self.in1[i * 8 : (i + 1) * 8])) diff --git a/coreblocks/fu/shift_unit.py b/coreblocks/fu/shift_unit.py index 7e73ae2f4..f0ce3dc2d 100644 --- a/coreblocks/fu/shift_unit.py +++ b/coreblocks/fu/shift_unit.py @@ -62,7 +62,7 @@ def elaborate(self, platform): with OneHotCase(ShiftUnitFn.Fn.SRL): m.d.comb += self.out.eq(self.in1 >> self.in2[0:xlen_log]) with OneHotCase(ShiftUnitFn.Fn.SRA): - m.d.comb += self.out.eq(Cat(self.in1, Repl(self.in1[xlen - 1], xlen)) >> self.in2[0:xlen_log]) + m.d.comb += self.out.eq(Cat(self.in1, self.in1[xlen - 1].replicate(xlen)) >> self.in2[0:xlen_log]) if self.zbb_enable: with OneHotCase(ShiftUnitFn.Fn.ROL): diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index ae207b1dd..7bf830436 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -50,7 +50,7 @@ def __init__( self.funct7 = Value.cast(funct7) def pack(self) -> Value: - return Cat(Repl(1, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7) + return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7) class ITypeInstr(RISCVInstr): @@ -62,7 +62,7 @@ def __init__(self, opcode: ValueLike, rd: ValueLike, funct3: ValueLike, rs1: Val self.imm = Value.cast(imm) def pack(self) -> Value: - return Cat(Repl(1, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm) + return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm) class STypeInstr(RISCVInstr): @@ -74,7 +74,7 @@ def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: Va self.rs2 = Value.cast(rs2) def pack(self) -> Value: - return Cat(Repl(1, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12]) + return Cat(C(0b11, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12]) class BTypeInstr(RISCVInstr): @@ -87,7 +87,7 @@ def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: Va def pack(self) -> Value: return Cat( - Repl(1, 2), + C(0b11, 2), self.opcode, self.imm[11], self.imm[1:5], @@ -106,7 +106,7 @@ def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike): self.imm = Value.cast(imm) def pack(self) -> Value: - return Cat(Repl(1, 2), self.opcode, self.rd, self.imm[12:]) + return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:]) class JTypeInstr(RISCVInstr): @@ -116,7 +116,7 @@ def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike): self.imm = Value.cast(imm) def pack(self) -> Value: - return Cat(Repl(1, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20]) + return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20]) class IllegalInstr(RISCVInstr): @@ -124,7 +124,7 @@ def __init__(self): pass def pack(self) -> Value: - return Repl(1, 32) # Instructions with all bits set to 1 are reserved to be illegal. + return C(1).replicate(32) # Instructions with all bits set to 1 are reserved to be illegal. class EBreakInstr(ITypeInstr): diff --git a/coreblocks/peripherals/wishbone.py b/coreblocks/peripherals/wishbone.py index e6a9bc2ff..bbf81eb18 100644 --- a/coreblocks/peripherals/wishbone.py +++ b/coreblocks/peripherals/wishbone.py @@ -1,6 +1,5 @@ from amaranth import * from amaranth.hdl.rec import DIR_FANIN, DIR_FANOUT -from amaranth.lib.scheduler import RoundRobin from functools import reduce from typing import List import operator @@ -8,7 +7,7 @@ from transactron import Method, def_method, TModule from transactron.core import Transaction from transactron.lib import AdapterTrans, BasicFifo -from transactron.utils import OneHotSwitchDynamic, assign +from transactron.utils import OneHotSwitchDynamic, assign, RoundRobin from transactron.lib.connectors import Forwarder diff --git a/scripts/core_graph.py b/scripts/core_graph.py index b12f1a44b..a589c205a 100755 --- a/scripts/core_graph.py +++ b/scripts/core_graph.py @@ -15,12 +15,12 @@ from coreblocks.params.genparams import GenParams # noqa: E402 from transactron.graph import TracingFragment # noqa: E402 -from test.test_core import TestElaboratable # noqa: E402 +from test.test_core import CoreTestElaboratable # noqa: E402 from coreblocks.params.configurations import basic_core_config # noqa: E402 from transactron.core import TransactionModule # noqa: E402 gp = GenParams(basic_core_config) -elaboratable = TestElaboratable(gp) +elaboratable = CoreTestElaboratable(gp) tm = TransactionModule(elaboratable) fragment = TracingFragment.get(tm, platform=None).prepare() diff --git a/stubs/amaranth/hdl/ast.pyi b/stubs/amaranth/hdl/ast.pyi index b22901c29..fa115b316 100644 --- a/stubs/amaranth/hdl/ast.pyi +++ b/stubs/amaranth/hdl/ast.pyi @@ -249,6 +249,26 @@ class Value(metaclass=ABCMeta): """Rotate right by constant amount.""" ... + def replicate(self, count : int) -> Value: + """Replication. + + A ``Value`` is replicated (repeated) several times to be used + on the RHS of assignments:: + + len(v.replicate(n)) == len(v) * n + + Parameters + ---------- + count : int + Number of replications. + + Returns + ------- + Value, out + Replicated value. + """ + ... + def eq(self, value: ValueLike) -> Assign: """Assignment. diff --git a/test/common/infrastructure.py b/test/common/infrastructure.py index 058d5b9ed..d3903738f 100644 --- a/test/common/infrastructure.py +++ b/test/common/infrastructure.py @@ -84,7 +84,7 @@ def debug_signals(self): return sigs -class TestModule(Elaboratable): +class _TestModule(Elaboratable): def __init__(self, tested_module: HasElaborate, add_transaction_module): self.tested_module = TransactionModule(tested_module) if add_transaction_module else tested_module self.add_transaction_module = add_transaction_module @@ -137,7 +137,7 @@ def _wrapping_function(self): class PysimSimulator(Simulator): def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True, traces_file=None): - test_module = TestModule(module, add_transaction_module) + test_module = _TestModule(module, add_transaction_module) tested_module = test_module.tested_module super().__init__(test_module) diff --git a/test/frontend/test_decode.py b/test/frontend/test_decode.py index 04ed07c71..8ae4001cc 100644 --- a/test/frontend/test_decode.py +++ b/test/frontend/test_decode.py @@ -1,47 +1,37 @@ -from amaranth import Elaboratable, Module - from transactron.lib import AdapterTrans, FIFO -from ..common import TestCaseWithSimulator, TestbenchIO +from ..common import TestCaseWithSimulator, TestbenchIO, SimpleTestCircuit, ModuleConnector from coreblocks.frontend.decode import Decode from coreblocks.params import GenParams, FetchLayouts, DecodeLayouts, OpType, Funct3, Funct7 from coreblocks.params.configurations import test_core_config -class TestElaboratable(Elaboratable): - def __init__(self, gen_params: GenParams): - self.gen_params = gen_params - - def elaborate(self, platform): - m = Module() +class TestDecode(TestCaseWithSimulator): + def setUp(self) -> None: + self.gen_params = GenParams(test_core_config.replace(start_pc=24)) fifo_in = FIFO(self.gen_params.get(FetchLayouts).raw_instr, depth=2) fifo_out = FIFO(self.gen_params.get(DecodeLayouts).decoded_instr, depth=2) - self.io_in = TestbenchIO(AdapterTrans(fifo_in.write)) - self.io_out = TestbenchIO(AdapterTrans(fifo_out.read)) + self.fifo_in_write = TestbenchIO(AdapterTrans(fifo_in.write)) + self.fifo_out_read = TestbenchIO(AdapterTrans(fifo_out.read)) self.decode = Decode(self.gen_params, fifo_in.read, fifo_out.write) - - m.submodules.decode = self.decode - m.submodules.io_in = self.io_in - m.submodules.io_out = self.io_out - m.submodules.fifo_in = fifo_in - m.submodules.fifo_out = fifo_out - - return m - - -class TestFetch(TestCaseWithSimulator): - def setUp(self) -> None: - self.gen_params = GenParams(test_core_config.replace(start_pc=24)) - self.test_module = TestElaboratable(self.gen_params) + self.m = SimpleTestCircuit( + ModuleConnector( + decode=self.decode, + fifo_in=fifo_in, + fifo_out=fifo_out, + fifo_in_write=self.fifo_in_write, + fifo_out_read=self.fifo_out_read, + ) + ) def decode_test_proc(self): # testing an OP_IMM instruction (test copied from test_decoder.py) - yield from self.test_module.io_in.call(instr=0x02A28213) - decoded = yield from self.test_module.io_out.call() + yield from self.fifo_in_write.call(instr=0x02A28213) + decoded = yield from self.fifo_out_read.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.ARITHMETIC) self.assertEqual(decoded["exec_fn"]["funct3"], Funct3.ADD) @@ -52,8 +42,8 @@ def decode_test_proc(self): self.assertEqual(decoded["imm"], 42) # testing an OP instruction (test copied from test_decoder.py) - yield from self.test_module.io_in.call(instr=0x003100B3) - decoded = yield from self.test_module.io_out.call() + yield from self.fifo_in_write.call(instr=0x003100B3) + decoded = yield from self.fifo_out_read.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.ARITHMETIC) self.assertEqual(decoded["exec_fn"]["funct3"], Funct3.ADD) @@ -63,8 +53,8 @@ def decode_test_proc(self): self.assertEqual(decoded["regs_l"]["rl_s2"], 3) # testing an illegal - yield from self.test_module.io_in.call(instr=0x0) - decoded = yield from self.test_module.io_out.call() + yield from self.fifo_in_write.call(instr=0x0) + decoded = yield from self.fifo_out_read.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.EXCEPTION) self.assertEqual(decoded["exec_fn"]["funct3"], Funct3._EILLEGALINSTR) @@ -73,8 +63,8 @@ def decode_test_proc(self): self.assertEqual(decoded["regs_l"]["rl_s1"], 0) self.assertEqual(decoded["regs_l"]["rl_s2"], 0) - yield from self.test_module.io_in.call(instr=0x0, access_fault=1) - decoded = yield from self.test_module.io_out.call() + yield from self.fifo_in_write.call(instr=0x0, access_fault=1) + decoded = yield from self.fifo_out_read.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.EXCEPTION) self.assertEqual(decoded["exec_fn"]["funct3"], Funct3._EINSTRACCESSFAULT) @@ -84,5 +74,5 @@ def decode_test_proc(self): self.assertEqual(decoded["regs_l"]["rl_s2"], 0) def test(self): - with self.run_simulation(self.test_module) as sim: + with self.run_simulation(self.m) as sim: sim.add_sync_process(self.decode_test_proc) diff --git a/test/frontend/test_fetch.py b/test/frontend/test_fetch.py index 782ec6866..734309d4b 100644 --- a/test/frontend/test_fetch.py +++ b/test/frontend/test_fetch.py @@ -12,7 +12,7 @@ from coreblocks.params import * from coreblocks.params.configurations import test_core_config from transactron.utils import ModuleConnector -from ..common import TestCaseWithSimulator, TestbenchIO, def_method_mock +from ..common import TestCaseWithSimulator, TestbenchIO, def_method_mock, SimpleTestCircuit class MockedICache(Elaboratable, ICacheInterface): @@ -35,33 +35,25 @@ def elaborate(self, platform): return m -class TestElaboratable(Elaboratable): - def __init__(self, gen_params: GenParams): - self.gen_params = gen_params - - def elaborate(self, platform): - m = Module() +class TestFetch(TestCaseWithSimulator): + def setUp(self) -> None: + self.gen_params = GenParams(test_core_config.replace(start_pc=0x18)) self.icache = MockedICache(self.gen_params) fifo = FIFO(self.gen_params.get(FetchLayouts).raw_instr, depth=2) self.io_out = TestbenchIO(AdapterTrans(fifo.read)) - self.fetch = Fetch(self.gen_params, self.icache, fifo.write) - self.verify_branch = TestbenchIO(AdapterTrans(self.fetch.verify_branch)) - - m.submodules.icache = self.icache - m.submodules.fetch = self.fetch - m.submodules.io_out = self.io_out - m.submodules.verify_branch = self.verify_branch - m.submodules.fifo = fifo - - return m + self.fetch = SimpleTestCircuit(Fetch(self.gen_params, self.icache, fifo.write)) + self.verify_branch = TestbenchIO(AdapterTrans(self.fetch._dut.verify_branch)) + self.m = ModuleConnector( + icache=self.icache, + fetch=self.fetch, + io_out=self.io_out, + verify_branch=self.verify_branch, + fifo=fifo, + ) -class TestFetch(TestCaseWithSimulator): - def setUp(self) -> None: - self.gen_params = GenParams(test_core_config.replace(start_pc=0x18)) - self.m = TestElaboratable(self.gen_params) self.instr_queue = deque() self.iterations = 500 @@ -112,11 +104,11 @@ def cache_process(): } ) - @def_method_mock(lambda: self.m.icache.issue_req_io, enable=lambda: len(input_q) < 2, sched_prio=1) + @def_method_mock(lambda: self.icache.issue_req_io, enable=lambda: len(input_q) < 2, sched_prio=1) def issue_req_mock(addr): input_q.append(addr) - @def_method_mock(lambda: self.m.icache.accept_res_io, enable=lambda: len(output_q) > 0) + @def_method_mock(lambda: self.icache.accept_res_io, enable=lambda: len(output_q) > 0) def accept_res_mock(): return output_q.popleft() @@ -130,9 +122,9 @@ def fetch_out_check(self): instr = self.instr_queue.popleft() if instr["is_branch"]: yield from self.random_wait(10) - yield from self.m.verify_branch.call(from_pc=instr["pc"], next_pc=instr["next_pc"]) + yield from self.verify_branch.call(from_pc=instr["pc"], next_pc=instr["next_pc"]) - v = yield from self.m.io_out.call() + v = yield from self.io_out.call() self.assertEqual(v["pc"], instr["pc"]) self.assertEqual(v["instr"], instr["instr"]) diff --git a/test/frontend/test_rvc.py b/test/frontend/test_rvc.py index 7375a2d15..668ead899 100644 --- a/test/frontend/test_rvc.py +++ b/test/frontend/test_rvc.py @@ -136,7 +136,12 @@ # c.lwsp x2, 4 (0x4112, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X2, funct3=Funct3.W, rs1=Registers.SP, imm=C(4, 12))), # c.jr x30 - (0x8F02, ITypeInstr(opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=Registers.X30, imm=Repl(0, 12))), + ( + 0x8F02, + ITypeInstr( + opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=Registers.X30, imm=C(0).replicate(12) + ), + ), # c.mv x2, x26 ( 0x816A, diff --git a/test/stages/test_retirement.py b/test/stages/test_retirement.py index f1a2f998f..a32c57dab 100644 --- a/test/stages/test_retirement.py +++ b/test/stages/test_retirement.py @@ -109,55 +109,54 @@ def setUp(self): # (and the retirement code doesn't have any special behaviour to handle these cases), but in this simple # test we don't care to make sure that the randomly generated inputs are correct in this way. + @def_method_mock(lambda self: self.retc.mock_rob_retire, enable=lambda self: bool(self.submit_q), sched_prio=1) + def retire_process(self): + return self.submit_q.popleft() + + # TODO: mocking really seems to dislike nonexclusive methods for some reason + @def_method_mock(lambda self: self.retc.mock_rob_peek, enable=lambda self: bool(self.submit_q)) + def peek_process(self): + return self.submit_q[0] + + def free_reg_process(self): + while self.rf_exp_q: + reg = yield from self.retc.free_rf_adapter.call() + self.assertEqual(reg["reg_id"], self.rf_exp_q.popleft()) + + def rat_process(self): + while self.rat_map_q: + current_map = self.rat_map_q.popleft() + wait_cycles = 0 + # this test waits for next rat pair to be correctly set and will timeout if that assignment fails + while (yield self.retc.rat.entries[current_map["rl_dst"]]) != current_map["rp_dst"]: + wait_cycles += 1 + if wait_cycles >= self.cycles + 10: + self.fail("RAT entry was not updated") + yield + self.assertFalse(self.submit_q) + self.assertFalse(self.rf_free_q) + + @def_method_mock(lambda self: self.retc.mock_rf_free, sched_prio=2) + def rf_free_process(self, reg_id): + self.assertEqual(reg_id, self.rf_free_q.popleft()) + + @def_method_mock(lambda self: self.retc.mock_precommit, sched_prio=2) + def precommit_process(self, rob_id, side_fx): + self.assertEqual(rob_id, self.precommit_q.popleft()) + + @def_method_mock(lambda self: self.retc.mock_exception_cause) + def exception_cause_process(self): + return {"cause": 0, "rob_id": 0} # keep exception cause method enabled + def test_rand(self): - retc = RetirementTestCircuit(self.gen_params) - - yield from retc.mock_fetch_stall.enable() - - @def_method_mock(lambda: retc.mock_rob_retire, enable=lambda: bool(self.submit_q), sched_prio=1) - def retire_process(): - return self.submit_q.popleft() - - # TODO: mocking really seems to dislike nonexclusive methods for some reason - @def_method_mock(lambda: retc.mock_rob_peek, enable=lambda: bool(self.submit_q)) - def peek_process(): - return self.submit_q[0] - - def free_reg_process(): - while self.rf_exp_q: - reg = yield from retc.free_rf_adapter.call() - self.assertEqual(reg["reg_id"], self.rf_exp_q.popleft()) - - def rat_process(): - while self.rat_map_q: - current_map = self.rat_map_q.popleft() - wait_cycles = 0 - # this test waits for next rat pair to be correctly set and will timeout if that assignment fails - while (yield retc.rat.entries[current_map["rl_dst"]]) != current_map["rp_dst"]: - wait_cycles += 1 - if wait_cycles >= self.cycles + 10: - self.fail("RAT entry was not updated") - yield - self.assertFalse(self.submit_q) - self.assertFalse(self.rf_free_q) - - @def_method_mock(lambda: retc.mock_rf_free, sched_prio=2) - def rf_free_process(reg_id): - self.assertEqual(reg_id, self.rf_free_q.popleft()) - - @def_method_mock(lambda: retc.mock_precommit, sched_prio=2) - def precommit_process(rob_id, side_fx): - self.assertEqual(rob_id, self.precommit_q.popleft()) - - @def_method_mock(lambda: retc.mock_exception_cause) - def exception_cause_process(): - return {"cause": 0, "rob_id": 0} # keep exception cause method enabled - - with self.run_simulation(retc) as sim: - sim.add_sync_process(retire_process) - sim.add_sync_process(peek_process) - sim.add_sync_process(free_reg_process) - sim.add_sync_process(rat_process) - sim.add_sync_process(rf_free_process) - sim.add_sync_process(precommit_process) - sim.add_sync_process(exception_cause_process) + self.retc = RetirementTestCircuit(self.gen_params) + + yield from self.retc.mock_fetch_stall.enable() # To be fixed + with self.run_simulation(self.retc) as sim: + sim.add_sync_process(self.retire_process) + sim.add_sync_process(self.peek_process) + sim.add_sync_process(self.free_reg_process) + sim.add_sync_process(self.rat_process) + sim.add_sync_process(self.rf_free_process) + sim.add_sync_process(self.precommit_process) + sim.add_sync_process(self.exception_cause_process) diff --git a/test/structs_common/test_rs.py b/test/structs_common/test_rs.py index c35630b2c..aa5fe67ea 100644 --- a/test/structs_common/test_rs.py +++ b/test/structs_common/test_rs.py @@ -1,10 +1,6 @@ -from typing import Iterable, Optional -from amaranth import Elaboratable, Module from amaranth.sim import Settle -from transactron.lib import AdapterTrans - -from ..common import TestCaseWithSimulator, TestbenchIO, get_outputs +from ..common import TestCaseWithSimulator, get_outputs, SimpleTestCircuit from coreblocks.structs_common.rs import RS from coreblocks.params import * @@ -26,40 +22,11 @@ def create_check_list(rs_entries_bits: int, insert_list: list[dict]) -> list[dic return check_list -class TestElaboratable(Elaboratable): - def __init__(self, gen_params: GenParams, ready_for: Optional[Iterable[Iterable[OpType]]] = None) -> None: - self.gen_params = gen_params - self.ready_for = ready_for - # test config GenParams specifies only one RS - it has the max number of entries - self.rs_entries = self.gen_params.max_rs_entries - self.rs_entries_bits = self.gen_params.max_rs_entries_bits - - def elaborate(self, platform) -> Module: - m = Module() - rs = RS(self.gen_params, 2**self.rs_entries_bits, self.ready_for) - - self.rs = rs - self.io_select = TestbenchIO(AdapterTrans(rs.select)) - self.io_insert = TestbenchIO(AdapterTrans(rs.insert)) - self.io_update = TestbenchIO(AdapterTrans(rs.update)) - self.io_take = TestbenchIO(AdapterTrans(rs.take)) - self.io_get_ready_list = [TestbenchIO(AdapterTrans(get_ready_list)) for get_ready_list in rs.get_ready_list] - - m.submodules.rs = rs - m.submodules.io_select = self.io_select - m.submodules.io_insert = self.io_insert - m.submodules.io_update = self.io_update - m.submodules.io_take = self.io_take - for n, io_get_ready_list in enumerate(self.io_get_ready_list): - m.submodules[f"io_get_ready_list_{n}"] = io_get_ready_list - - return m - - class TestRSMethodInsert(TestCaseWithSimulator): def test_insert(self): self.gen_params = GenParams(test_core_config) - self.m = TestElaboratable(self.gen_params) + self.rs_entries_bits = self.gen_params.max_rs_entries_bits + self.m = SimpleTestCircuit(RS(self.gen_params, 2**self.rs_entries_bits, None)) self.insert_list = [ { "rs_entry_id": id, @@ -79,9 +46,9 @@ def test_insert(self): "pc": id, }, } - for id in range(2**self.m.rs_entries_bits) + for id in range(2**self.rs_entries_bits) ] - self.check_list = create_check_list(self.m.rs_entries_bits, self.insert_list) + self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.simulation_process) @@ -89,21 +56,22 @@ def test_insert(self): def simulation_process(self): # After each insert, entry should be marked as full for index, record in enumerate(self.insert_list): - self.assertEqual((yield self.m.rs.data[index].rec_full), 0) - yield from self.m.io_insert.call(record) + self.assertEqual((yield self.m._dut.data[index].rec_full), 0) + yield from self.m.insert.call(record) yield Settle() - self.assertEqual((yield self.m.rs.data[index].rec_full), 1) + self.assertEqual((yield self.m._dut.data[index].rec_full), 1) yield Settle() # Check data integrity - for expected, record in zip(self.check_list, self.m.rs.data): + for expected, record in zip(self.check_list, self.m._dut.data): self.assertEqual(expected, (yield from get_outputs(record))) class TestRSMethodSelect(TestCaseWithSimulator): def test_select(self): self.gen_params = GenParams(test_core_config) - self.m = TestElaboratable(self.gen_params) + self.rs_entries_bits = self.gen_params.max_rs_entries_bits + self.m = SimpleTestCircuit(RS(self.gen_params, 2**self.rs_entries_bits, None)) self.insert_list = [ { "rs_entry_id": id, @@ -123,9 +91,9 @@ def test_select(self): "pc": id, }, } - for id in range(2**self.m.rs_entries_bits - 1) + for id in range(2**self.rs_entries_bits - 1) ] - self.check_list = create_check_list(self.m.rs_entries_bits, self.insert_list) + self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.simulation_process) @@ -133,40 +101,41 @@ def test_select(self): def simulation_process(self): # In the beginning the select method should be ready and id should be selectable for index, record in enumerate(self.insert_list): - self.assertEqual((yield self.m.rs.select.ready), 1) - self.assertEqual((yield from self.m.io_select.call())["rs_entry_id"], index) + self.assertEqual((yield self.m._dut.select.ready), 1) + self.assertEqual((yield from self.m.select.call())["rs_entry_id"], index) yield Settle() - self.assertEqual((yield self.m.rs.data[index].rec_reserved), 1) - yield from self.m.io_insert.call(record) + self.assertEqual((yield self.m._dut.data[index].rec_reserved), 1) + yield from self.m.insert.call(record) yield Settle() # Check if RS state is as expected - for expected, record in zip(self.check_list, self.m.rs.data): + for expected, record in zip(self.check_list, self.m._dut.data): self.assertEqual((yield record.rec_full), expected["rec_full"]) self.assertEqual((yield record.rec_ready), expected["rec_ready"]) self.assertEqual((yield record.rec_reserved), expected["rec_reserved"]) # Reserve the last entry, then select ready should be false - self.assertEqual((yield self.m.rs.select.ready), 1) - self.assertEqual((yield from self.m.io_select.call())["rs_entry_id"], 3) + self.assertEqual((yield self.m._dut.select.ready), 1) + self.assertEqual((yield from self.m.select.call())["rs_entry_id"], 3) yield Settle() - self.assertEqual((yield self.m.rs.select.ready), 0) + self.assertEqual((yield self.m._dut.select.ready), 0) # After take, select ready should be true, with 0 index returned - yield from self.m.io_take.call() + yield from self.m.take.call() yield Settle() - self.assertEqual((yield self.m.rs.select.ready), 1) - self.assertEqual((yield from self.m.io_select.call())["rs_entry_id"], 0) + self.assertEqual((yield self.m._dut.select.ready), 1) + self.assertEqual((yield from self.m.select.call())["rs_entry_id"], 0) # After reservation, select is false again yield Settle() - self.assertEqual((yield self.m.rs.select.ready), 0) + self.assertEqual((yield self.m._dut.select.ready), 0) class TestRSMethodUpdate(TestCaseWithSimulator): def test_update(self): self.gen_params = GenParams(test_core_config) - self.m = TestElaboratable(self.gen_params) + self.rs_entries_bits = self.gen_params.max_rs_entries_bits + self.m = SimpleTestCircuit(RS(self.gen_params, 2**self.rs_entries_bits, None)) self.insert_list = [ { "rs_entry_id": id, @@ -186,9 +155,9 @@ def test_update(self): "pc": id, }, } - for id in range(2**self.m.rs_entries_bits) + for id in range(2**self.rs_entries_bits) ] - self.check_list = create_check_list(self.m.rs_entries_bits, self.insert_list) + self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.simulation_process) @@ -196,29 +165,29 @@ def test_update(self): def simulation_process(self): # Insert all reacords for record in self.insert_list: - yield from self.m.io_insert.call(record) + yield from self.m.insert.call(record) yield Settle() # Check data integrity - for expected, record in zip(self.check_list, self.m.rs.data): + for expected, record in zip(self.check_list, self.m._dut.data): self.assertEqual(expected, (yield from get_outputs(record))) # Update second entry first SP, instruction should be not ready value_sp1 = 1010 - self.assertEqual((yield self.m.rs.data[1].rec_ready), 0) - yield from self.m.io_update.call(reg_id=2, reg_val=value_sp1) + self.assertEqual((yield self.m._dut.data[1].rec_ready), 0) + yield from self.m.update.call(reg_id=2, reg_val=value_sp1) yield Settle() - self.assertEqual((yield self.m.rs.data[1].rs_data.rp_s1), 0) - self.assertEqual((yield self.m.rs.data[1].rs_data.s1_val), value_sp1) - self.assertEqual((yield self.m.rs.data[1].rec_ready), 0) + self.assertEqual((yield self.m._dut.data[1].rs_data.rp_s1), 0) + self.assertEqual((yield self.m._dut.data[1].rs_data.s1_val), value_sp1) + self.assertEqual((yield self.m._dut.data[1].rec_ready), 0) # Update second entry second SP, instruction should be ready value_sp2 = 2020 - yield from self.m.io_update.call(reg_id=3, reg_val=value_sp2) + yield from self.m.update.call(reg_id=3, reg_val=value_sp2) yield Settle() - self.assertEqual((yield self.m.rs.data[1].rs_data.rp_s2), 0) - self.assertEqual((yield self.m.rs.data[1].rs_data.s2_val), value_sp2) - self.assertEqual((yield self.m.rs.data[1].rec_ready), 1) + self.assertEqual((yield self.m._dut.data[1].rs_data.rp_s2), 0) + self.assertEqual((yield self.m._dut.data[1].rs_data.s2_val), value_sp2) + self.assertEqual((yield self.m._dut.data[1].rec_ready), 1) # Insert new instruction to entries 0 and 1, check if update of multiple registers works reg_id = 4 @@ -239,24 +208,25 @@ def simulation_process(self): } for index in range(2): - yield from self.m.io_insert.call(rs_entry_id=index, rs_data=data) + yield from self.m.insert.call(rs_entry_id=index, rs_data=data) yield Settle() - self.assertEqual((yield self.m.rs.data[index].rec_ready), 0) + self.assertEqual((yield self.m._dut.data[index].rec_ready), 0) - yield from self.m.io_update.call(reg_id=reg_id, reg_val=value_spx) + yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx) yield Settle() for index in range(2): - self.assertEqual((yield self.m.rs.data[index].rs_data.rp_s1), 0) - self.assertEqual((yield self.m.rs.data[index].rs_data.rp_s2), 0) - self.assertEqual((yield self.m.rs.data[index].rs_data.s1_val), value_spx) - self.assertEqual((yield self.m.rs.data[index].rs_data.s2_val), value_spx) - self.assertEqual((yield self.m.rs.data[index].rec_ready), 1) + self.assertEqual((yield self.m._dut.data[index].rs_data.rp_s1), 0) + self.assertEqual((yield self.m._dut.data[index].rs_data.rp_s2), 0) + self.assertEqual((yield self.m._dut.data[index].rs_data.s1_val), value_spx) + self.assertEqual((yield self.m._dut.data[index].rs_data.s2_val), value_spx) + self.assertEqual((yield self.m._dut.data[index].rec_ready), 1) class TestRSMethodTake(TestCaseWithSimulator): def test_take(self): self.gen_params = GenParams(test_core_config) - self.m = TestElaboratable(self.gen_params) + self.rs_entries_bits = self.gen_params.max_rs_entries_bits + self.m = SimpleTestCircuit(RS(self.gen_params, 2**self.rs_entries_bits, None)) self.insert_list = [ { "rs_entry_id": id, @@ -276,9 +246,9 @@ def test_take(self): "pc": id, }, } - for id in range(2**self.m.rs_entries_bits) + for id in range(2**self.rs_entries_bits) ] - self.check_list = create_check_list(self.m.rs_entries_bits, self.insert_list) + self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.simulation_process) @@ -286,32 +256,32 @@ def test_take(self): def simulation_process(self): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.io_insert.call(record) + yield from self.m.insert.call(record) yield Settle() # Check data integrity - for expected, record in zip(self.check_list, self.m.rs.data): + for expected, record in zip(self.check_list, self.m._dut.data): self.assertEqual(expected, (yield from get_outputs(record))) # Take first instruction - self.assertEqual((yield self.m.rs.take.ready), 1) - data = yield from self.m.io_take.call(rs_entry_id=0) + self.assertEqual((yield self.m._dut.take.ready), 1) + data = yield from self.m.take.call(rs_entry_id=0) for key in data: self.assertEqual(data[key], self.check_list[0]["rs_data"][key]) yield Settle() - self.assertEqual((yield self.m.rs.take.ready), 0) + self.assertEqual((yield self.m._dut.take.ready), 0) # Update second instuction and take it reg_id = 2 value_spx = 1 - yield from self.m.io_update.call(reg_id=reg_id, reg_val=value_spx) + yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx) yield Settle() - self.assertEqual((yield self.m.rs.take.ready), 1) - data = yield from self.m.io_take.call(rs_entry_id=1) + self.assertEqual((yield self.m._dut.take.ready), 1) + data = yield from self.m.take.call(rs_entry_id=1) for key in data: self.assertEqual(data[key], self.check_list[1]["rs_data"][key]) yield Settle() - self.assertEqual((yield self.m.rs.take.ready), 0) + self.assertEqual((yield self.m._dut.take.ready), 0) # Insert two new ready instructions and take them reg_id = 0 @@ -333,28 +303,29 @@ def simulation_process(self): } for index in range(2): - yield from self.m.io_insert.call(rs_entry_id=index, rs_data=entry_data) + yield from self.m.insert.call(rs_entry_id=index, rs_data=entry_data) yield Settle() - self.assertEqual((yield self.m.rs.data[index].rec_ready), 1) - self.assertEqual((yield self.m.rs.take.ready), 1) + self.assertEqual((yield self.m._dut.data[index].rec_ready), 1) + self.assertEqual((yield self.m._dut.take.ready), 1) - data = yield from self.m.io_take.call(rs_entry_id=0) + data = yield from self.m.take.call(rs_entry_id=0) for key in data: self.assertEqual(data[key], entry_data[key]) yield Settle() - self.assertEqual((yield self.m.rs.take.ready), 1) + self.assertEqual((yield self.m._dut.take.ready), 1) - data = yield from self.m.io_take.call(rs_entry_id=1) + data = yield from self.m.take.call(rs_entry_id=1) for key in data: self.assertEqual(data[key], entry_data[key]) yield Settle() - self.assertEqual((yield self.m.rs.take.ready), 0) + self.assertEqual((yield self.m._dut.take.ready), 0) class TestRSMethodGetReadyList(TestCaseWithSimulator): def test_get_ready_list(self): self.gen_params = GenParams(test_core_config) - self.m = TestElaboratable(self.gen_params) + self.rs_entries_bits = self.gen_params.max_rs_entries_bits + self.m = SimpleTestCircuit(RS(self.gen_params, 2**self.rs_entries_bits, None)) self.insert_list = [ { "rs_entry_id": id, @@ -374,9 +345,9 @@ def test_get_ready_list(self): "pc": id, }, } - for id in range(2**self.m.rs_entries_bits) + for id in range(2**self.rs_entries_bits) ] - self.check_list = create_check_list(self.m.rs_entries_bits, self.insert_list) + self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.simulation_process) @@ -384,30 +355,34 @@ def test_get_ready_list(self): def simulation_process(self): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.io_insert.call(record) + yield from self.m.insert.call(record) yield Settle() # Check ready vector integrity - ready_list = (yield from self.m.io_get_ready_list[0].call())["ready_list"] + ready_list = (yield from self.m.get_ready_list[0].call())["ready_list"] self.assertEqual(ready_list, 0b0011) # Take first record and check ready vector integrity - yield from self.m.io_take.call(rs_entry_id=0) + yield from self.m.take.call(rs_entry_id=0) yield Settle() - ready_list = (yield from self.m.io_get_ready_list[0].call())["ready_list"] + ready_list = (yield from self.m.get_ready_list[0].call())["ready_list"] self.assertEqual(ready_list, 0b0010) # Take second record and check ready vector integrity - yield from self.m.io_take.call(rs_entry_id=1) + yield from self.m.take.call(rs_entry_id=1) yield Settle() - option_ready_list = yield from self.m.io_get_ready_list[0].call_try() + option_ready_list = yield from self.m.get_ready_list[0].call_try() self.assertIsNone(option_ready_list) class TestRSMethodTwoGetReadyLists(TestCaseWithSimulator): def test_two_get_ready_lists(self): self.gen_params = GenParams(test_core_config) - self.m = TestElaboratable(self.gen_params, [[OpType(1), OpType(2)], [OpType(3), OpType(4)]]) + self.rs_entries = self.gen_params.max_rs_entries + self.rs_entries_bits = self.gen_params.max_rs_entries_bits + self.m = SimpleTestCircuit( + RS(self.gen_params, 2**self.rs_entries_bits, [[OpType(1), OpType(2)], [OpType(3), OpType(4)]]) + ) self.insert_list = [ { "rs_entry_id": id, @@ -426,9 +401,9 @@ def test_two_get_ready_lists(self): "imm": id, }, } - for id in range(self.m.rs_entries) + for id in range(self.rs_entries) ] - self.check_list = create_check_list(self.m.rs_entries_bits, self.insert_list) + self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.simulation_process) @@ -436,24 +411,24 @@ def test_two_get_ready_lists(self): def simulation_process(self): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.io_insert.call(record) + yield from self.m.insert.call(record) yield Settle() masks = [0b0011, 0b1100] - for i in range(self.m.rs.rs_entries + 1): + for i in range(self.m._dut.rs_entries + 1): # Check ready vectors' integrity for j in range(2): - ready_list = yield from self.m.io_get_ready_list[j].call_try() + ready_list = yield from self.m.get_ready_list[j].call_try() if masks[j]: self.assertEqual(ready_list, {"ready_list": masks[j]}) else: self.assertIsNone(ready_list) # Take a record - if i == self.m.rs.rs_entries: + if i == self.m._dut.rs_entries: break - yield from self.m.io_take.call(rs_entry_id=i) + yield from self.m.take.call(rs_entry_id=i) yield Settle() masks = [mask & ~(1 << i) for mask in masks] diff --git a/test/test_core.py b/test/test_core.py index 5d8625e13..934a28449 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -33,7 +33,7 @@ from riscvmodel.variant import RV32I -class TestElaboratable(Elaboratable): +class CoreTestElaboratable(Elaboratable): def __init__(self, gen_params: GenParams, instr_mem: list[int] = [0], data_mem: Optional[list[int]] = None): self.gen_params = gen_params self.instr_mem = instr_mem @@ -84,7 +84,7 @@ def gen_riscv_lui_instr(dst, imm): class TestCoreBase(TestCaseWithSimulator): gen_params: GenParams - m: TestElaboratable + m: CoreTestElaboratable def check_RAT_alloc(self, rat, expected_alloc_count=None): # noqa: N802 allocated = [] @@ -183,7 +183,7 @@ def simple_test(self): def test_simple(self): self.gen_params = GenParams(basic_core_config) - m = TestElaboratable(self.gen_params) + m = CoreTestElaboratable(self.gen_params) self.m = m with self.run_simulation(m) as sim: @@ -242,7 +242,7 @@ def test_randomized(self): self.instr_mem = list(map(lambda x: x.encode(), all_instr)) - m = TestElaboratable(self.gen_params, instr_mem=self.instr_mem) + m = CoreTestElaboratable(self.gen_params, instr_mem=self.instr_mem) self.m = m with self.run_simulation(m) as sim: @@ -323,7 +323,7 @@ def test_asm_source(self): self.gen_params = GenParams(self.configuration) bin_src = self.prepare_source(self.source_file) - self.m = TestElaboratable(self.gen_params, instr_mem=bin_src) + self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.run_and_check) @@ -397,6 +397,6 @@ def run_with_interrupt(self): def test_interrupted_prog(self): bin_src = self.prepare_source(self.source_file) - self.m = TestElaboratable(self.gen_params, instr_mem=bin_src) + self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src) with self.run_simulation(self.m) as sim: sim.add_sync_process(self.run_with_interrupt) diff --git a/test/transactions/test_adapter.py b/test/transactions/test_adapter.py index fbb779a12..48728cb02 100644 --- a/test/transactions/test_adapter.py +++ b/test/transactions/test_adapter.py @@ -1,10 +1,9 @@ from amaranth import * from transactron import Method, def_method, TModule -from transactron.lib import AdapterTrans -from ..common import TestCaseWithSimulator, TestbenchIO, data_layout +from ..common import TestCaseWithSimulator, data_layout, SimpleTestCircuit, ModuleConnector class Echo(Elaboratable): @@ -45,35 +44,19 @@ def _(arg): return m -class TestElaboratable(Elaboratable): - def __init__(self): - self.echo = Echo() - self.consumer = Consumer() - self.io_echo = TestbenchIO(AdapterTrans(self.echo.action)) - self.io_consume = TestbenchIO(AdapterTrans(self.consumer.action)) - - def elaborate(self, platform): - m = TModule() - - m.submodules.echo = self.echo - m.submodules.io_echo = self.io_echo - m.submodules.consumer = self.consumer - m.submodules.io_consume = self.io_consume - - return m - - class TestAdapterTrans(TestCaseWithSimulator): def proc(self): for _ in range(3): # this would previously timeout if the output layout was empty (as is in this case) - yield from self.t.io_consume.call() + yield from self.consumer.action.call() for expected in [4, 1, 0]: - obtained = (yield from self.t.io_echo.call(data=expected))["data"] + obtained = (yield from self.echo.action.call(data=expected))["data"] self.assertEqual(expected, obtained) def test_single(self): - self.t = t = TestElaboratable() + self.echo = SimpleTestCircuit(Echo()) + self.consumer = SimpleTestCircuit(Consumer()) + self.m = ModuleConnector(echo=self.echo, consumer=self.consumer) - with self.run_simulation(t, max_cycles=100) as sim: + with self.run_simulation(self.m, max_cycles=100) as sim: sim.add_sync_process(self.proc) diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index 2f6e1ede9..b60232e43 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -10,6 +10,7 @@ "OneHotSwitch", "ModuleConnector", "Scheduler", + "RoundRobin", ] @@ -183,3 +184,58 @@ def elaborate(self, platform): m.d.sync += grant_reg.eq(self.grant) return m + + +class RoundRobin(Elaboratable): + """Round-robin scheduler. + For a given set of requests, the round-robin scheduler will + grant one request. Once it grants a request, if any other + requests are active, it grants the next active request with + a greater number, restarting from zero once it reaches the + highest one. + Use :class:`EnableInserter` to control when the scheduler + is updated. + + Implementation ported from amaranth lib. + + Parameters + ---------- + count : int + Number of requests. + Attributes + ---------- + requests : Signal(count), in + Set of requests. + grant : Signal(range(count)), out + Number of the granted request. Does not change if there are no + active requests. + valid : Signal(), out + Asserted if grant corresponds to an active request. Deasserted + otherwise, i.e. if no requests are active. + """ + + def __init__(self, *, count): + if not isinstance(count, int) or count < 0: + raise ValueError("Count must be a non-negative integer, not {!r}".format(count)) + self.count = count + + self.requests = Signal(count) + self.grant = Signal(range(count)) + self.valid = Signal() + + def elaborate(self, platform): + m = Module() + + with m.Switch(self.grant): + for i in range(self.count): + with m.Case(i): + for pred in reversed(range(i)): + with m.If(self.requests[pred]): + m.d.sync += self.grant.eq(pred) + for succ in reversed(range(i + 1, self.count)): + with m.If(self.requests[succ]): + m.d.sync += self.grant.eq(succ) + + m.d.sync += self.valid.eq(self.requests.any()) + + return m