Skip to content

Commit

Permalink
RF using async memory
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Nov 25, 2024
1 parent 5e0e7d6 commit 3117ef2
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions coreblocks/core_structs/rf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from amaranth import *
import amaranth.lib.memory as memory
from transactron import Method, Transaction, def_method, TModule
from coreblocks.interface.layouts import RFLayouts
from coreblocks.params import GenParams
Expand All @@ -15,10 +16,8 @@ def __init__(self, *, gen_params: GenParams):
layouts = gen_params.get(RFLayouts)
self.internal_layout = make_layout(("reg_val", gen_params.isa.xlen), ("valid", 1))
self.read_layout = layouts.rf_read_out
self.entries = Array(
Signal(self.internal_layout, init={"reg_val": 0, "valid": k == 0})
for k in range(2**gen_params.phys_regs_bits)
)
self.entries = memory.Memory(shape=gen_params.isa.xlen, depth=2**gen_params.phys_regs_bits, init=[])
self.valids = Array(Signal(init=k == 0) for k in range(2**gen_params.phys_regs_bits))

self.read1 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out)
self.read2 = Method(i=layouts.rf_read_in, o=layouts.rf_read_out)
Expand All @@ -41,49 +40,57 @@ def __init__(self, *, gen_params: GenParams):
def elaborate(self, platform):
m = TModule()

m.submodules += [self.perf_rf_valid_time, self.perf_num_valid]
m.submodules += [self.entries, self.perf_rf_valid_time, self.perf_num_valid]

being_written = Signal(self.gen_params.phys_regs_bits)
written_value = Signal(self.gen_params.isa.xlen)

write_port = self.entries.write_port()
read_port_1 = self.entries.read_port(domain="comb")
read_port_2 = self.entries.read_port(domain="comb")

@def_method(m, self.read1)
def _(reg_id: Value):
forward = Signal()
m.d.av_comb += forward.eq((being_written == reg_id) & (reg_id != 0))
m.d.av_comb += read_port_1.addr.eq(reg_id)
return {
"reg_val": Mux(forward, written_value, self.entries[reg_id].reg_val),
"valid": Mux(forward, 1, self.entries[reg_id].valid),
"reg_val": Mux(forward, written_value, read_port_1.data),
"valid": Mux(forward, 1, self.valids[reg_id]),
}

@def_method(m, self.read2)
def _(reg_id: Value):
forward = Signal()
m.d.av_comb += forward.eq((being_written == reg_id) & (reg_id != 0))
m.d.av_comb += read_port_2.addr.eq(reg_id)
return {
"reg_val": Mux(forward, written_value, self.entries[reg_id].reg_val),
"valid": Mux(forward, 1, self.entries[reg_id].valid),
"reg_val": Mux(forward, written_value, read_port_2.data),
"valid": Mux(forward, 1, self.valids[reg_id]),
}

@def_method(m, self.write)
def _(reg_id: Value, reg_val: Value):
zero_reg = reg_id == 0
m.d.comb += being_written.eq(reg_id)
m.d.av_comb += written_value.eq(reg_val)
m.d.av_comb += write_port.addr.eq(reg_id)
m.d.av_comb += write_port.data.eq(reg_val)
with m.If(~(zero_reg)):
m.d.sync += self.entries[reg_id].reg_val.eq(reg_val)
m.d.sync += self.entries[reg_id].valid.eq(1)
m.d.comb += write_port.en.eq(1)
m.d.sync += self.valids[reg_id].eq(1)
self.perf_rf_valid_time.start(m, slot=reg_id)

@def_method(m, self.free)
def _(reg_id: Value):
with m.If(reg_id != 0):
m.d.sync += self.entries[reg_id].valid.eq(0)
m.d.sync += self.valids[reg_id].eq(0)
self.perf_rf_valid_time.stop(m, slot=reg_id)

if self.perf_num_valid.metrics_enabled():
num_valid = Signal(self.gen_params.phys_regs_bits + 1)
m.d.comb += num_valid.eq(
popcount(Cat(self.entries[reg_id].valid for reg_id in range(2**self.gen_params.phys_regs_bits)))
popcount(Cat(self.valids[reg_id] for reg_id in range(2**self.gen_params.phys_regs_bits)))
)
with Transaction(name="perf").body(m):
self.perf_num_valid.add(m, num_valid)
Expand Down

0 comments on commit 3117ef2

Please sign in to comment.