From 240734d0ca5a8c42feaf8ad5c70887f53008bb65 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Tue, 26 Nov 2024 20:53:18 +0100 Subject: [PATCH] RF using async memory (#759) --- coreblocks/core_structs/rf.py | 35 ++++++++++++++++++++--------------- test/test_core.py | 3 ++- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/coreblocks/core_structs/rf.py b/coreblocks/core_structs/rf.py index 6865219b0..e22e97f98 100644 --- a/coreblocks/core_structs/rf.py +++ b/coreblocks/core_structs/rf.py @@ -1,10 +1,10 @@ from amaranth import * +import amaranth.lib.memory as memory from transactron import Method, Transaction, def_method, TModule from coreblocks.interface.layouts import RFLayouts from coreblocks.params import GenParams from transactron.lib.metrics import HwExpHistogram, TaggedLatencyMeasurer from transactron.utils.amaranth_ext.functions import popcount -from transactron.utils.transactron_helpers import make_layout __all__ = ["RegisterFile"] @@ -13,12 +13,9 @@ class RegisterFile(Elaboratable): def __init__(self, *, gen_params: GenParams): self.gen_params = gen_params layouts = gen_params.get(RFLayouts) - self.internal_layout = make_layout(("reg_val", gen_params.isa.xlen), ("valid", 1)) self.read_layout = layouts.rf_read_out - self.entries = Array( - Signal(self.internal_layout, init={"reg_val": 0, "valid": k == 0}) - for k in range(2**gen_params.phys_regs_bits) - ) + self.entries = memory.Memory(shape=gen_params.isa.xlen, depth=2**gen_params.phys_regs_bits, init=[]) + self.valids = Array(Signal(init=k == 0) for k in range(2**gen_params.phys_regs_bits)) self.read1 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out) self.read2 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out) @@ -41,27 +38,33 @@ def __init__(self, *, gen_params: GenParams): def elaborate(self, platform): m = TModule() - m.submodules += [self.perf_rf_valid_time, self.perf_num_valid] + m.submodules += [self.entries, self.perf_rf_valid_time, self.perf_num_valid] being_written = Signal(self.gen_params.phys_regs_bits) written_value = Signal(self.gen_params.isa.xlen) + write_port = self.entries.write_port() + read_port_1 = self.entries.read_port(domain="comb") + read_port_2 = self.entries.read_port(domain="comb") + @def_method(m, self.read1) def _(reg_id: Value): forward = Signal() m.d.av_comb += forward.eq((being_written == reg_id) & (reg_id != 0)) + m.d.av_comb += read_port_1.addr.eq(reg_id) return { - "reg_val": Mux(forward, written_value, self.entries[reg_id].reg_val), - "valid": Mux(forward, 1, self.entries[reg_id].valid), + "reg_val": Mux(forward, written_value, read_port_1.data), + "valid": Mux(forward, 1, self.valids[reg_id]), } @def_method(m, self.read2) def _(reg_id: Value): forward = Signal() m.d.av_comb += forward.eq((being_written == reg_id) & (reg_id != 0)) + m.d.av_comb += read_port_2.addr.eq(reg_id) return { - "reg_val": Mux(forward, written_value, self.entries[reg_id].reg_val), - "valid": Mux(forward, 1, self.entries[reg_id].valid), + "reg_val": Mux(forward, written_value, read_port_2.data), + "valid": Mux(forward, 1, self.valids[reg_id]), } @def_method(m, self.write) @@ -69,21 +72,23 @@ def _(reg_id: Value, reg_val: Value): zero_reg = reg_id == 0 m.d.comb += being_written.eq(reg_id) m.d.av_comb += written_value.eq(reg_val) + m.d.av_comb += write_port.addr.eq(reg_id) + m.d.av_comb += write_port.data.eq(reg_val) with m.If(~(zero_reg)): - m.d.sync += self.entries[reg_id].reg_val.eq(reg_val) - m.d.sync += self.entries[reg_id].valid.eq(1) + m.d.comb += write_port.en.eq(1) + m.d.sync += self.valids[reg_id].eq(1) self.perf_rf_valid_time.start(m, slot=reg_id) @def_method(m, self.free) def _(reg_id: Value): with m.If(reg_id != 0): - m.d.sync += self.entries[reg_id].valid.eq(0) + m.d.sync += self.valids[reg_id].eq(0) self.perf_rf_valid_time.stop(m, slot=reg_id) if self.perf_num_valid.metrics_enabled(): num_valid = Signal(self.gen_params.phys_regs_bits + 1) m.d.comb += num_valid.eq( - popcount(Cat(self.entries[reg_id].valid for reg_id in range(2**self.gen_params.phys_regs_bits))) + popcount(Cat(self.valids[reg_id] for reg_id in range(2**self.gen_params.phys_regs_bits))) ) with Transaction(name="perf").body(m): self.perf_num_valid.add(m, num_valid) diff --git a/test/test_core.py b/test/test_core.py index 237360f53..92dba973e 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -70,7 +70,8 @@ def get_phys_reg_rrat(self, sim: TestbenchContext, reg_id): return sim.get(self.m.core.RRAT.entries[reg_id]) def get_arch_reg_val(self, sim: TestbenchContext, reg_id): - return sim.get(self.m.core.RF.entries[(self.get_phys_reg_rrat(sim, reg_id))].reg_val) + # TODO: better stubs for memory, remove ignore + return sim.get(self.m.core.RF.entries.data[(self.get_phys_reg_rrat(sim, reg_id))]) # type: ignore class TestCoreAsmSourceBase(TestCoreBase):