Skip to content

Commit

Permalink
FIFO reservation station (#634)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Apr 2, 2024
1 parent 4f25673 commit 8c6128b
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 36 deletions.
28 changes: 28 additions & 0 deletions coreblocks/func_blocks/fu/common/fifo_rs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from amaranth import *
from transactron import TModule
from coreblocks.func_blocks.fu.common.rs import RSBase

__all__ = ["FifoRS"]


class FifoRS(RSBase):
def elaborate(self, platform):
m = TModule()

front = Signal(self.rs_entries_bits)
back = Signal(self.rs_entries_bits)

select_possible = ~self.data[back].rec_reserved

take_possible = self.data_ready.bit_select(front, 1) & self.data[front].rec_full
take_vector = take_possible << front

self._elaborate(m, back, select_possible, take_vector)

with m.If(self.select.run):
m.d.sync += back.eq(back + 1)

with m.If(self.take.run):
m.d.sync += front.eq(front + 1)

return m
55 changes: 28 additions & 27 deletions coreblocks/func_blocks/fu/common/rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from coreblocks.interface.layouts import RSLayouts
from transactron.lib.metrics import HwExpHistogram, TaggedLatencyMeasurer
from transactron.utils import RecordDict
from transactron.utils import assign
from transactron.utils.assign import AssignType
from transactron.utils.amaranth_ext.functions import popcount
from transactron.utils.transactron_helpers import make_layout

__all__ = ["RS"]
__all__ = ["RSBase", "RS"]


class RS(Elaboratable):
class RSBase(Elaboratable):
def __init__(
self,
gen_params: GenParams,
Expand Down Expand Up @@ -57,34 +59,23 @@ def __init__(
sample_width=self.rs_entries_bits + 1,
)

def elaborate(self, platform):
m = TModule()

m.submodules.enc_select = PriorityEncoder(width=self.rs_entries)
def _elaborate(self, m: TModule, selected_id: Value, select_possible: Value, take_vector: Value):
m.submodules += [self.perf_rs_wait_time, self.perf_num_full]

for i, record in enumerate(self.data):
m.d.comb += self.data_ready[i].eq(
~record.rs_data.rp_s1.bool() & ~record.rs_data.rp_s2.bool() & record.rec_full.bool()
)

select_vector = Cat(~record.rec_reserved for record in self.data)
select_possible = select_vector.any()

take_vector = Cat(self.data_ready[i] & record.rec_full for i, record in enumerate(self.data))
take_possible = take_vector.any()

ready_lists: list[Value] = []
for op_list in self.ready_for:
op_vector = Cat(Cat(record.rs_data.exec_fn.op_type == op for op in op_list).any() for record in self.data)
ready_lists.append(take_vector & op_vector)

m.d.comb += m.submodules.enc_select.i.eq(select_vector)

@def_method(m, self.select, ready=select_possible)
def _() -> Signal:
m.d.sync += self.data[m.submodules.enc_select.o].rec_reserved.eq(1)
return m.submodules.enc_select.o
def _() -> RecordDict:
m.d.sync += self.data[selected_id].rec_reserved.eq(1)
return {"rs_entry_id": selected_id}

@def_method(m, self.insert)
def _(rs_entry_id: Value, rs_data: Value) -> None:
Expand All @@ -105,21 +96,15 @@ def _(reg_id: Value, reg_val: Value) -> None:
m.d.sync += record.rs_data.rp_s2.eq(0)
m.d.sync += record.rs_data.s2_val.eq(reg_val)

@def_method(m, self.take, ready=take_possible)
@def_method(m, self.take)
def _(rs_entry_id: Value) -> RecordDict:
record = self.data[rs_entry_id]
m.d.sync += record.rec_reserved.eq(0)
m.d.sync += record.rec_full.eq(0)
self.perf_rs_wait_time.stop(m, slot=rs_entry_id)
return {
"s1_val": record.rs_data.s1_val,
"s2_val": record.rs_data.s2_val,
"rp_dst": record.rs_data.rp_dst,
"rob_id": record.rs_data.rob_id,
"exec_fn": record.rs_data.exec_fn,
"imm": record.rs_data.imm,
"pc": record.rs_data.pc,
}
out = Signal(self.layouts.take_out)
m.d.av_comb += assign(out, record.rs_data, fields=AssignType.COMMON)
return out

for get_ready_list, ready_list in zip(self.get_ready_list, ready_lists):

Expand All @@ -133,4 +118,20 @@ def _() -> RecordDict:
with Transaction(name="perf").body(m):
self.perf_num_full.add(m, num_full)


class RS(RSBase):
def elaborate(self, platform):
m = TModule()

m.submodules.enc_select = enc_select = PriorityEncoder(width=self.rs_entries)

select_vector = Cat(~record.rec_reserved for record in self.data)
select_possible = select_vector.any()

take_vector = Cat(self.data_ready[i] & record.rec_full for i, record in enumerate(self.data))

m.d.comb += enc_select.i.eq(select_vector)

self._elaborate(m, enc_select.o, select_possible, take_vector)

return m
2 changes: 1 addition & 1 deletion coreblocks/scheduler/wakeup_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def elaborate(self, platform):
with Transaction().body(m):
ready = self.get_ready(m)
ready_width = ready.shape().size
last = Signal(range(ready_width))
last = Signal((ready_width - 1).bit_length())
for i in range(ready_width):
with m.If(ready.ready_list[i]):
m.d.comb += last.eq(i)
Expand Down
138 changes: 130 additions & 8 deletions test/structs_common/test_rs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import random
from collections import deque
from parameterized import parameterized_class

from amaranth.sim import Settle

from transactron.testing import TestCaseWithSimulator, get_outputs, SimpleTestCircuit

from coreblocks.func_blocks.fu.common.rs import RS
from coreblocks.func_blocks.fu.common.rs import RS, RSBase
from coreblocks.func_blocks.fu.common.fifo_rs import FifoRS
from coreblocks.params import *
from coreblocks.params.configurations import test_core_config
from coreblocks.frontend.decoder import OpType
Expand All @@ -20,6 +25,123 @@ def create_check_list(rs_entries_bits: int, insert_list: list[dict]) -> list[dic
return check_list


def create_data_list(gen_params: GenParams, count: int):
data_list = [
{
"rp_s1": random.randrange(1, 2**gen_params.phys_regs_bits) * random.randrange(2),
"rp_s2": random.randrange(1, 2**gen_params.phys_regs_bits) * random.randrange(2),
"rp_dst": random.randrange(2**gen_params.phys_regs_bits),
"rob_id": k,
"exec_fn": {
"op_type": 1,
"funct3": 2,
"funct7": 3,
},
"s1_val": k,
"s2_val": k,
"imm": k,
"pc": k,
}
for k in range(count)
]
return data_list


@parameterized_class(
("name", "rs_elaboratable"),
[
(
"RS",
RS,
),
(
"FifoRS",
FifoRS,
),
],
)
class TestRS(TestCaseWithSimulator):
rs_elaboratable: type[RSBase]

def test_rs(self):
random.seed(42)
self.gen_params = GenParams(test_core_config)
self.rs_entries_bits = self.gen_params.max_rs_entries_bits
self.m = SimpleTestCircuit(self.rs_elaboratable(self.gen_params, 2**self.rs_entries_bits, 0, None))
self.data_list = create_data_list(self.gen_params, 10 * 2**self.rs_entries_bits)
self.select_queue: deque[int] = deque()
self.regs_to_update: set[int] = set()
self.rs_entries: dict[int, int] = {}
self.finished = False

with self.run_simulation(self.m) as sim:
sim.add_sync_process(self.select_process)
sim.add_sync_process(self.insert_process)
sim.add_sync_process(self.update_process)
sim.add_sync_process(self.take_process)

def select_process(self):
for k in range(len(self.data_list)):
rs_entry_id = (yield from self.m.select.call())["rs_entry_id"]
self.select_queue.appendleft(rs_entry_id)
self.rs_entries[rs_entry_id] = k

def insert_process(self):
for data in self.data_list:
yield Settle() # so that select_process can insert into the queue
while not self.select_queue:
yield
yield Settle()
rs_entry_id = self.select_queue.pop()
yield from self.m.insert.call({"rs_entry_id": rs_entry_id, "rs_data": data})
if data["rp_s1"]:
self.regs_to_update.add(data["rp_s1"])
if data["rp_s2"]:
self.regs_to_update.add(data["rp_s2"])

def update_process(self):
while not self.finished:
yield Settle() # so that insert_process can insert into the set
if not self.regs_to_update:
yield
continue
reg_id = random.choice(list(self.regs_to_update))
self.regs_to_update.discard(reg_id)
reg_val = random.randrange(1000)
for rs_entry_id, k in self.rs_entries.items():
if self.data_list[k]["rp_s1"] == reg_id:
self.data_list[k]["rp_s1"] = 0
self.data_list[k]["s1_val"] = reg_val
if self.data_list[k]["rp_s2"] == reg_id:
self.data_list[k]["rp_s2"] = 0
self.data_list[k]["s2_val"] = reg_val
yield from self.m.update.call(reg_id=reg_id, reg_val=reg_val)

def take_process(self):
taken: set[int] = set()
yield from self.m.get_ready_list[0].call_init()
yield Settle()
for k in range(len(self.data_list)):
yield Settle()
while not (yield from self.m.get_ready_list[0].done()):
yield
ready_list = (yield from self.m.get_ready_list[0].call_result())["ready_list"]
possible_ids = [i for i in range(2**self.rs_entries_bits) if ready_list & (1 << i)]
if not possible_ids:
yield
continue
rs_entry_id = random.choice(possible_ids)
k = self.rs_entries[rs_entry_id]
taken.add(k)
test_data = dict(self.data_list[k])
del test_data["rp_s1"]
del test_data["rp_s2"]
data = yield from self.m.take.call(rs_entry_id=rs_entry_id)
self.assertEqual(data, test_data)
self.assertEqual(taken, set(range(len(self.data_list))))
self.finished = True


class TestRSMethodInsert(TestCaseWithSimulator):
def test_insert(self):
self.gen_params = GenParams(test_core_config)
Expand Down Expand Up @@ -261,24 +383,24 @@ def simulation_process(self):
self.assertEqual(expected, (yield from get_outputs(record)))

# Take first instruction
self.assertEqual((yield self.m._dut.take.ready), 1)
self.assertEqual((yield self.m._dut.get_ready_list[0].ready), 1)
data = yield from self.m.take.call(rs_entry_id=0)
for key in data:
self.assertEqual(data[key], self.check_list[0]["rs_data"][key])
yield Settle()
self.assertEqual((yield self.m._dut.take.ready), 0)
self.assertEqual((yield self.m._dut.get_ready_list[0].ready), 0)

# Update second instuction and take it
reg_id = 2
value_spx = 1
yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx)
yield Settle()
self.assertEqual((yield self.m._dut.take.ready), 1)
self.assertEqual((yield self.m._dut.get_ready_list[0].ready), 1)
data = yield from self.m.take.call(rs_entry_id=1)
for key in data:
self.assertEqual(data[key], self.check_list[1]["rs_data"][key])
yield Settle()
self.assertEqual((yield self.m._dut.take.ready), 0)
self.assertEqual((yield self.m._dut.get_ready_list[0].ready), 0)

# Insert two new ready instructions and take them
reg_id = 0
Expand All @@ -302,20 +424,20 @@ def simulation_process(self):
for index in range(2):
yield from self.m.insert.call(rs_entry_id=index, rs_data=entry_data)
yield Settle()
self.assertEqual((yield self.m._dut.take.ready), 1)
self.assertEqual((yield self.m._dut.get_ready_list[0].ready), 1)
self.assertEqual((yield self.m._dut.data_ready[index]), 1)

data = yield from self.m.take.call(rs_entry_id=0)
for key in data:
self.assertEqual(data[key], entry_data[key])
yield Settle()
self.assertEqual((yield self.m._dut.take.ready), 1)
self.assertEqual((yield self.m._dut.get_ready_list[0].ready), 1)

data = yield from self.m.take.call(rs_entry_id=1)
for key in data:
self.assertEqual(data[key], entry_data[key])
yield Settle()
self.assertEqual((yield self.m._dut.take.ready), 0)
self.assertEqual((yield self.m._dut.get_ready_list[0].ready), 0)


class TestRSMethodGetReadyList(TestCaseWithSimulator):
Expand Down

0 comments on commit 8c6128b

Please sign in to comment.