Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synchronous register file using MemoryBank #765

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions coreblocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ def elaborate(self, platform):
get_free_reg=free_rf_fifo.read,
rat_rename=frat.rename,
rob_put=rob.put,
rf_read1=rf.read1,
rf_read2=rf.read2,
rf_read_req1=rf.read_req1,
rf_read_req2=rf.read_req2,
rf_read_resp1=rf.read_resp1,
rf_read_resp2=rf.read_resp2,
reservation_stations=self.func_blocks_unifier.rs_blocks,
gen_params=self.gen_params,
)
Expand Down
39 changes: 23 additions & 16 deletions coreblocks/core_structs/rf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from amaranth import *
import amaranth.lib.memory as memory
from transactron import Method, Transaction, def_method, TModule
from coreblocks.interface.layouts import RFLayouts
from coreblocks.params import GenParams
from transactron.lib.metrics import HwExpHistogram, TaggedLatencyMeasurer
from transactron.lib.storage import MemoryBank
from transactron.utils.amaranth_ext.functions import popcount

__all__ = ["RegisterFile"]
Expand All @@ -14,11 +14,18 @@ def __init__(self, *, gen_params: GenParams):
self.gen_params = gen_params
layouts = gen_params.get(RFLayouts)
self.read_layout = layouts.rf_read_out
self.entries = memory.Memory(shape=gen_params.isa.xlen, depth=2**gen_params.phys_regs_bits, init=[])
self.entries = MemoryBank(
data_layout=[("data", gen_params.isa.xlen)],
elem_count=2**gen_params.phys_regs_bits,
read_ports=2,
transparent=True,
)
self.valids = Array(Signal(init=k == 0) for k in range(2**gen_params.phys_regs_bits))

self.read1 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out)
self.read2 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out)
self.read_req1 = Method(i=layouts.rf_read_in)
self.read_req2 = Method(i=layouts.rf_read_in)
self.read_resp1 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out)
self.read_resp2 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out)
self.write = Method(i=layouts.rf_write)
self.free = Method(i=layouts.rf_free)

Expand All @@ -43,27 +50,29 @@ def elaborate(self, platform):
being_written = Signal(self.gen_params.phys_regs_bits)
written_value = Signal(self.gen_params.isa.xlen)

write_port = self.entries.write_port()
read_port_1 = self.entries.read_port(domain="comb")
read_port_2 = self.entries.read_port(domain="comb")
@def_method(m, self.read_req1)
def _(reg_id: Value):
self.entries.read_req[0](m, addr=reg_id)

@def_method(m, self.read_req2)
def _(reg_id: Value):
self.entries.read_req[1](m, addr=reg_id)

@def_method(m, self.read1)
@def_method(m, self.read_resp1)
def _(reg_id: Value):
forward = Signal()
m.d.av_comb += forward.eq((being_written == reg_id) & (reg_id != 0))
m.d.av_comb += read_port_1.addr.eq(reg_id)
return {
"reg_val": Mux(forward, written_value, read_port_1.data),
"reg_val": Mux(forward, written_value, self.entries.read_resp[0](m).data),
"valid": Mux(forward, 1, self.valids[reg_id]),
}

@def_method(m, self.read2)
@def_method(m, self.read_resp2)
def _(reg_id: Value):
forward = Signal()
m.d.av_comb += forward.eq((being_written == reg_id) & (reg_id != 0))
m.d.av_comb += read_port_2.addr.eq(reg_id)
return {
"reg_val": Mux(forward, written_value, read_port_2.data),
"reg_val": Mux(forward, written_value, self.entries.read_resp[1](m).data),
"valid": Mux(forward, 1, self.valids[reg_id]),
}

Expand All @@ -72,10 +81,8 @@ def _(reg_id: Value, reg_val: Value):
zero_reg = reg_id == 0
m.d.comb += being_written.eq(reg_id)
m.d.av_comb += written_value.eq(reg_val)
m.d.av_comb += write_port.addr.eq(reg_id)
m.d.av_comb += write_port.data.eq(reg_val)
with m.If(~(zero_reg)):
m.d.comb += write_port.en.eq(1)
self.entries.write(m, addr=reg_id, data={"data": reg_val})
m.d.sync += self.valids[reg_id].eq(1)
self.perf_rf_valid_time.start(m, slot=reg_id)

Expand Down
57 changes: 41 additions & 16 deletions coreblocks/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def __init__(
get_instr: Method,
push_instr: Method,
rs_select: Sequence[tuple[Method, set[OpType]]],
rf_read_req1: Method,
rf_read_req2: Method,
gen_params: GenParams
):
"""
Expand All @@ -202,6 +204,12 @@ def __init__(

- A method used for allocating an entry in the RS. Uses `RSLayouts.select_out`.
- A set of `OpType`\\s that can be handled by this RS.
rf_read_req1: Method
Method used for requesting value of first source register and information if it is valid.
Uses `RFLayouts.rf_read_out`.
rf_read_req2: Method
Method used for requesting value of second source register and information if it is valid.
Uses `RFLayouts.rf_read_out`.
gen_params: GenParams
Core generation parameters.
"""
Expand All @@ -214,6 +222,8 @@ def __init__(
self.get_instr = get_instr
self.rs_select = rs_select
self.push_instr = push_instr
self.rf_read_req1 = rf_read_req1
self.rf_read_req2 = rf_read_req2

def decode_optype_set(self, optypes: set[OpType]) -> int:
res = 0x0
Expand Down Expand Up @@ -242,6 +252,9 @@ def elaborate(self, platform):

self.push_instr(m, data_out)

self.rf_read_req1(m, instr.regs_p.rp_s1)
self.rf_read_req2(m, instr.regs_p.rp_s2)

return m


Expand All @@ -258,8 +271,8 @@ def __init__(
*,
get_instr: Method,
rs_insert: Sequence[Method],
rf_read1: Method,
rf_read2: Method,
rf_read_resp1: Method,
rf_read_resp2: Method,
gen_params: GenParams
):
"""
Expand All @@ -270,10 +283,10 @@ def __init__(
rs_insert: Sequence[Method]
Sequence of methods used for pushing an instruction into the RS. Ordering of this list
determines the ID of a specific RS. They use `RSLayouts.insert_in`
rf_read1: Method
rf_read_resp1: Method
Method used for getting value of first source register and information if it is valid.
Uses `RFLayouts.rf_read_out` and `RFLayouts.rf_read_in`.
rf_read2: Method
rf_read_resp2: Method
Method used for getting value of second source register and information if it is valid.
Uses `RFLayouts.rf_read_out` and `RFLayouts.rf_read_in`.
gen_params: GenParams
Expand All @@ -283,8 +296,8 @@ def __init__(

self.get_instr = get_instr
self.rs_insert = rs_insert
self.rf_read1 = rf_read1
self.rf_read2 = rf_read2
self.rf_read_resp1 = rf_read_resp1
self.rf_read_resp2 = rf_read_resp2

def elaborate(self, platform):
m = TModule()
Expand All @@ -293,8 +306,8 @@ def elaborate(self, platform):
# therefore we can use single transaction here.
with Transaction().body(m):
instr = self.get_instr(m)
source1 = self.rf_read1(m, reg_id=instr.regs_p.rp_s1)
source2 = self.rf_read2(m, reg_id=instr.regs_p.rp_s2)
source1 = self.rf_read_resp1(m, reg_id=instr.regs_p.rp_s1)
source2 = self.rf_read_resp2(m, reg_id=instr.regs_p.rp_s2)

# when core is flushed, rp_dst are discarded.
# source operands may never become ready, skip waiting for them in any in RSes/FBs.
Expand Down Expand Up @@ -357,8 +370,10 @@ def __init__(
get_free_reg: Method,
rat_rename: Method,
rob_put: Method,
rf_read1: Method,
rf_read2: Method,
rf_read_req1: Method,
rf_read_req2: Method,
rf_read_resp1: Method,
rf_read_resp2: Method,
reservation_stations: Sequence[tuple[FuncBlock, set[OpType]]],
gen_params: GenParams
):
Expand All @@ -375,10 +390,16 @@ def __init__(
and `RATLayouts.rat_rename_out`.
rob_put: Method
Method used for getting a free entry in ROB. Uses `ROBLayouts.data_layout`.
rf_read1: Method
rf_read_req1: Method
Method used for requesting value of first source register and information if it is valid.
Uses `RFLayouts.rf_read_out`.
rf_read_req2: Method
Method used for requesting value of second source register and information if it is valid.
Uses `RFLayouts.rf_read_out`.
rf_read_resp1: Method
Method used for getting value of first source register and information if it is valid.
Uses `RFLayouts.rf_read_out` and `RFLayouts.rf_read_in`.
rf_read2: Method
rf_read_resp2: Method
Method used for getting value of second source register and information if it is valid.
Uses `RFLayouts.rf_read_out` and `RFLayouts.rf_read_in`.
reservation_stations: Sequence[FuncBlock]
Expand All @@ -392,8 +413,10 @@ def __init__(
self.get_free_reg = get_free_reg
self.rat_rename = rat_rename
self.rob_put = rob_put
self.rf_read1 = rf_read1
self.rf_read2 = rf_read2
self.rf_read_req1 = rf_read_req1
self.rf_read_req2 = rf_read_req2
self.rf_read_resp1 = rf_read_resp1
self.rf_read_resp2 = rf_read_resp2
self.rs = reservation_stations

def elaborate(self, platform):
Expand Down Expand Up @@ -429,13 +452,15 @@ def elaborate(self, platform):
get_instr=reg_alloc_out_buf.read,
rs_select=[(rs.select, optypes) for rs, optypes in self.rs],
push_instr=rs_select_out_buf.write,
rf_read_req1=self.rf_read_req1,
rf_read_req2=self.rf_read_req2,
)

m.submodules.rs_insertion = RSInsertion(
get_instr=rs_select_out_buf.read,
rs_insert=[rs.insert for rs, _ in self.rs],
rf_read1=self.rf_read1,
rf_read2=self.rf_read2,
rf_read_resp1=self.rf_read_resp1,
rf_read_resp2=self.rf_read_resp2,
gen_params=self.gen_params,
)

Expand Down
15 changes: 14 additions & 1 deletion test/scheduler/test_rs_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from amaranth import *

from coreblocks.params import GenParams
from coreblocks.interface.layouts import RSLayouts, SchedulerLayouts
from coreblocks.interface.layouts import RFLayouts, RSLayouts, SchedulerLayouts
from coreblocks.arch import Funct3, Funct7
from coreblocks.arch import OpType
from coreblocks.params.configurations import test_core_config
Expand All @@ -26,6 +26,7 @@ def elaborate(self, platform):
m = Module()

rs_layouts = self.gen_params.get(RSLayouts, rs_entries_bits=self.gen_params.max_rs_entries_bits)
rf_layouts = self.gen_params.get(RFLayouts)
scheduler_layouts = self.gen_params.get(SchedulerLayouts)

# data structures
Expand All @@ -37,13 +38,17 @@ def elaborate(self, platform):
m.submodules.instr_out = self.instr_out = TestbenchIO(AdapterTrans(out_fifo.read))
m.submodules.rs1_alloc = self.rs1_alloc = TestbenchIO(Adapter(o=rs_layouts.rs.select_out))
m.submodules.rs2_alloc = self.rs2_alloc = TestbenchIO(Adapter(o=rs_layouts.rs.select_out))
m.submodules.rf_read_req1 = self.rf_read_req1 = TestbenchIO(Adapter(i=rf_layouts.rf_read_in))
m.submodules.rf_read_req2 = self.rf_read_req2 = TestbenchIO(Adapter(i=rf_layouts.rf_read_in))

# rs selector
m.submodules.selector = self.selector = RSSelection(
gen_params=self.gen_params,
get_instr=instr_fifo.read,
rs_select=[(self.rs1_alloc.adapter.iface, _rs1_optypes), (self.rs2_alloc.adapter.iface, _rs2_optypes)],
push_instr=out_fifo.write,
rf_read_req1=self.rf_read_req1.adapter.iface,
rf_read_req2=self.rf_read_req2.adapter.iface,
)

return m
Expand Down Expand Up @@ -114,6 +119,14 @@ def eff():

return process()

@def_method_mock(lambda self: self.m.rf_read_req1)
def rf_read_req1_mock(self, reg_id):
pass

@def_method_mock(lambda self: self.m.rf_read_req2)
def rf_read_req2_mock(self, reg_id):
pass

def create_output_process(self, instr_count: int, random_wait: int = 0):
async def process(sim: TestbenchContext):
for _ in range(instr_count):
Expand Down
6 changes: 4 additions & 2 deletions test/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def elaborate(self, platform):
get_free_reg=free_rf_fifo.read,
rat_rename=rat.rename,
rob_put=self.rob.put,
rf_read1=self.rf.read1,
rf_read2=self.rf.read2,
rf_read_req1=self.rf.read_req1,
rf_read_req2=self.rf.read_req2,
rf_read_resp1=self.rf.read_resp1,
rf_read_resp2=self.rf.read_resp2,
reservation_stations=rs_blocks,
gen_params=self.gen_params,
)
Expand Down
2 changes: 1 addition & 1 deletion test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_phys_reg_rrat(self, sim: TestbenchContext, reg_id):

def get_arch_reg_val(self, sim: TestbenchContext, reg_id):
# TODO: better stubs for memory, remove ignore
return sim.get(self.m.core.RF.entries.data[(self.get_phys_reg_rrat(sim, reg_id))]) # type: ignore
return sim.get(self.m.core.RF.entries.mem.data[(self.get_phys_reg_rrat(sim, reg_id))]) # type: ignore


class TestCoreAsmSourceBase(TestCoreBase):
Expand Down
Loading