diff --git a/coreblocks/core.py b/coreblocks/core.py index 1cec20059..2adcbbe53 100644 --- a/coreblocks/core.py +++ b/coreblocks/core.py @@ -75,8 +75,8 @@ def __init__(self, *, gen_params: GenParams, wb_instr_bus: WishboneBus, wb_data_ gen=self.gen_params, get_result=self.func_blocks_unifier.get_result, rob_mark_done=self.ROB.mark_done, - rs_write_val=self.func_blocks_unifier.update, - rf_write_val=self.RF.write, + rs_update=self.func_blocks_unifier.update, + rf_write=self.RF.write, ) self.csr_generic = GenericCSRRegisters(self.gen_params) diff --git a/coreblocks/frontend/decode.py b/coreblocks/frontend/decode.py index e6513e158..aefd00d5f 100644 --- a/coreblocks/frontend/decode.py +++ b/coreblocks/frontend/decode.py @@ -42,7 +42,7 @@ def elaborate(self, platform): with Transaction().body(m): raw = self.get_raw(m) - m.d.top_comb += instr_decoder.instr.eq(raw.data) + m.d.top_comb += instr_decoder.instr.eq(raw.instr) # Jump-branch unit requires information if the instruction was # decoded from a compressed instruction. To avoid adding a new signal diff --git a/coreblocks/frontend/fetch.py b/coreblocks/frontend/fetch.py index 74c3861a1..96e0da3af 100644 --- a/coreblocks/frontend/fetch.py +++ b/coreblocks/frontend/fetch.py @@ -82,7 +82,7 @@ def stall(): m.d.sync += self.pc.eq(target.addr) m.d.comb += instr.eq(res.instr) - self.cont(m, data=instr, pc=target.addr, access_fault=fetch_error, rvc=0) + self.cont(m, instr=instr, pc=target.addr, access_fault=fetch_error, rvc=0) @def_method(m, self.verify_branch, ready=stalled) def _(from_pc: Value, next_pc: Value): @@ -210,7 +210,7 @@ def elaborate(self, platform) -> TModule: with m.If(~cache_resp.error): m.d.sync += current_pc.eq(current_pc + Mux(is_rvc, C(2, 3), C(4, 3))) - self.cont(m, data=instr, pc=current_pc, access_fault=cache_resp.error, rvc=is_rvc) + self.cont(m, instr=instr, pc=current_pc, access_fault=cache_resp.error, rvc=is_rvc) @def_method(m, self.verify_branch, ready=(stalled & ~flushing)) def _(from_pc: Value, next_pc: Value): diff --git a/coreblocks/fu/fu_decoder.py b/coreblocks/fu/fu_decoder.py index 0e3b7939a..510ee30f0 100644 --- a/coreblocks/fu/fu_decoder.py +++ b/coreblocks/fu/fu_decoder.py @@ -1,7 +1,7 @@ from typing import Sequence, Type from amaranth import * -from coreblocks.params import GenParams, CommonLayouts +from coreblocks.params import GenParams, CommonLayoutFields from enum import IntFlag @@ -19,9 +19,9 @@ class Decoder(Elaboratable): """ def __init__(self, gen_params: GenParams, decode_fn: Type[IntFlag], ops: Sequence[tuple], check_optype: bool): - layouts = gen_params.get(CommonLayouts) + layouts = gen_params.get(CommonLayoutFields) - self.exec_fn = Record(layouts.exec_fn) + self.exec_fn = Record(layouts.exec_fn_layout) self.decode_fn = Signal(decode_fn) self.ops = ops self.check_optype = check_optype diff --git a/coreblocks/lsu/dummyLsu.py b/coreblocks/lsu/dummyLsu.py index cb506aafc..d1530dc4c 100644 --- a/coreblocks/lsu/dummyLsu.py +++ b/coreblocks/lsu/dummyLsu.py @@ -204,9 +204,9 @@ def __init__(self, gen_params: GenParams, bus: WishboneMaster) -> None: self.fu_layouts = gen_params.get(FuncUnitLayouts) self.lsu_layouts = gen_params.get(LSULayouts) - self.insert = Method(i=self.lsu_layouts.rs_insert_in) - self.select = Method(o=self.lsu_layouts.rs_select_out) - self.update = Method(i=self.lsu_layouts.rs_update_in) + self.insert = Method(i=self.lsu_layouts.rs.insert_in) + self.select = Method(o=self.lsu_layouts.rs.select_out) + self.update = Method(i=self.lsu_layouts.rs.update_in) self.get_result = Method(o=self.fu_layouts.accept) self.precommit = Method(i=self.lsu_layouts.precommit) @@ -215,7 +215,7 @@ def __init__(self, gen_params: GenParams, bus: WishboneMaster) -> None: def elaborate(self, platform): m = TModule() reserved = Signal() # means that current_instr is reserved - current_instr = Record(self.lsu_layouts.rs_data_layout + [("valid", 1)]) + current_instr = Record(self.lsu_layouts.rs.data_layout + [("valid", 1)]) m.submodules.internal = internal = LSUDummyInternals(self.gen_params, self.bus, current_instr) @@ -233,12 +233,12 @@ def _(rs_data: Record, rs_entry_id: Value): m.d.sync += current_instr.valid.eq(1) @def_method(m, self.update) - def _(tag: Value, value: Value): - with m.If(current_instr.rp_s1 == tag): - m.d.sync += current_instr.s1_val.eq(value) + def _(reg_id: Value, reg_val: Value): + with m.If(current_instr.rp_s1 == reg_id): + m.d.sync += current_instr.s1_val.eq(reg_val) m.d.sync += current_instr.rp_s1.eq(0) - with m.If(current_instr.rp_s2 == tag): - m.d.sync += current_instr.s2_val.eq(value) + with m.If(current_instr.rp_s2 == reg_id): + m.d.sync += current_instr.s2_val.eq(reg_val) m.d.sync += current_instr.rp_s2.eq(0) @def_method(m, self.get_result, result_ready) diff --git a/coreblocks/params/genparams.py b/coreblocks/params/genparams.py index 2b6611aa9..9bf84c23c 100644 --- a/coreblocks/params/genparams.py +++ b/coreblocks/params/genparams.py @@ -7,6 +7,7 @@ from .icache_params import ICacheParameters from .fu_params import extensions_supported from ..peripherals.wishbone import WishboneParameters +from transactron.utils import make_hashable from typing import TYPE_CHECKING @@ -35,10 +36,11 @@ class DependentCache: """ def __init__(self): - self._depcache: dict[tuple[Type, frozenset[tuple[str, Any]]], Type] = {} + self._depcache: dict[tuple[Type, Any], Type] = {} def get(self, cls: Type[T], **kwargs) -> T: - v = self._depcache.get((cls, frozenset(kwargs.items())), None) + cache_key = make_hashable(kwargs) + v = self._depcache.get((cls, cache_key), None) if v is None: positional_count = cls.__init__.__code__.co_argcount @@ -50,7 +52,7 @@ def get(self, cls: Type[T], **kwargs) -> T: v = cls(self, **kwargs) else: v = cls(**kwargs) - self._depcache[(cls, frozenset(kwargs.items()))] = v + self._depcache[(cls, cache_key)] = v return v diff --git a/coreblocks/params/layouts.py b/coreblocks/params/layouts.py index 2a0f1a3e0..6df1970b7 100644 --- a/coreblocks/params/layouts.py +++ b/coreblocks/params/layouts.py @@ -1,11 +1,12 @@ from coreblocks.params import GenParams, OpType, Funct7, Funct3 from coreblocks.params.isa import ExceptionCause from transactron.utils.utils import layout_subset +from transactron.utils import LayoutList, LayoutListField __all__ = [ + "CommonLayoutFields", "SchedulerLayouts", "ROBLayouts", - "CommonLayouts", "FetchLayouts", "DecodeLayouts", "FuncUnitLayouts", @@ -20,170 +21,324 @@ ] -class CommonLayouts: +class CommonLayoutFields: + """Commonly used layout fields.""" + def __init__(self, gen_params: GenParams): - self.exec_fn = [ - ("op_type", OpType), - ("funct3", Funct3), - ("funct7", Funct7), - ] + self.op_type: LayoutListField = ("op_type", OpType) + """Decoded operation type.""" - self.regs_l = [ - ("rl_s1", gen_params.isa.reg_cnt_log), - ("rl_s2", gen_params.isa.reg_cnt_log), - ("rl_dst", gen_params.isa.reg_cnt_log), - ] + self.funct3: LayoutListField = ("funct3", Funct3) + """RISC V funct3 value.""" - self.regs_p = [ - ("rp_dst", gen_params.phys_regs_bits), - ("rp_s1", gen_params.phys_regs_bits), - ("rp_s2", gen_params.phys_regs_bits), - ] + self.funct7: LayoutListField = ("funct7", Funct7) + """RISC V funct7 value.""" + + self.rl_s1: LayoutListField = ("rl_s1", gen_params.isa.reg_cnt_log) + """Logical register number of first source operand.""" + + self.rl_s2: LayoutListField = ("rl_s2", gen_params.isa.reg_cnt_log) + """Logical register number of second source operand.""" + + self.rl_dst: LayoutListField = ("rl_dst", gen_params.isa.reg_cnt_log) + """Logical register number of destination operand.""" + + self.rp_s1: LayoutListField = ("rp_s1", gen_params.phys_regs_bits) + """Physical register number of first source operand.""" + + self.rp_s2: LayoutListField = ("rp_s2", gen_params.phys_regs_bits) + """Physical register number of second source operand.""" + + self.rp_dst: LayoutListField = ("rp_dst", gen_params.phys_regs_bits) + """Physical register number of destination operand.""" + + self.imm: LayoutListField = ("imm", gen_params.isa.xlen) + """Immediate value.""" + + self.csr: LayoutListField = ("csr", gen_params.isa.csr_alen) + """CSR number.""" + + self.pc: LayoutListField = ("pc", gen_params.isa.xlen) + """Program counter value.""" + + self.rob_id: LayoutListField = ("rob_id", gen_params.rob_entries_bits) + """Reorder buffer entry identifier.""" + + self.s1_val: LayoutListField = ("s1_val", gen_params.isa.xlen) + """Value of first source operand.""" + + self.s2_val: LayoutListField = ("s2_val", gen_params.isa.xlen) + """Value of second source operand.""" + + self.reg_val: LayoutListField = ("reg_val", gen_params.isa.xlen) + """Value of some physical register.""" + + self.addr: LayoutListField = ("addr", gen_params.isa.xlen) + """Memory address.""" + + self.data: LayoutListField = ("data", gen_params.isa.xlen) + """Piece of data.""" + + self.instr: LayoutListField = ("instr", gen_params.isa.ilen) + """RISC V instruction.""" + + self.exec_fn_layout: LayoutList = [self.op_type, self.funct3, self.funct7] + """Decoded instruction, in layout form.""" + + self.exec_fn: LayoutListField = ("exec_fn", self.exec_fn_layout) + """Decoded instruction.""" + + self.regs_l: LayoutListField = ("regs_l", [self.rl_s1, self.rl_s2, self.rl_dst]) + """Logical register numbers - as described in the RISC V manual. They index the RATs.""" + + self.regs_p: LayoutListField = ("regs_p", [self.rp_s1, self.rp_s2, self.rp_dst]) + """Physical register numbers. They index the register file.""" + + self.reg_id: LayoutListField = ("reg_id", gen_params.phys_regs_bits) + """Physical register ID.""" + + self.exception: LayoutListField = ("exception", 1) + """Exception is raised for this instruction.""" + + self.error: LayoutListField = ("error", 1) + """Request ended with an error.""" class SchedulerLayouts: + """Layouts used in the scheduler.""" + def __init__(self, gen_params: GenParams): - common = gen_params.get(CommonLayouts) - self.reg_alloc_in = [ - ("exec_fn", common.exec_fn), - ("regs_l", common.regs_l), - ("imm", gen_params.isa.xlen), - ("csr", gen_params.isa.csr_alen), - ("pc", gen_params.isa.xlen), + fields = gen_params.get(CommonLayoutFields) + + self.rs_selected: LayoutListField = ("rs_selected", gen_params.rs_number_bits) + """Reservation Station number for the instruction.""" + + self.rs_entry_id: LayoutListField = ("rs_entry_id", gen_params.max_rs_entries_bits) + """Reservation station entry ID for the instruction.""" + + self.regs_p_alloc_out: LayoutListField = ("regs_p", [fields.rp_dst]) + """Physical register number for the destination operand, after allocation.""" + + self.regs_l_rob_in: LayoutListField = ( + "regs_l", + [ + fields.rl_dst, + ("rl_dst_v", 1), + ], + ) + """Logical register number for the destination operand, before ROB allocation.""" + + self.reg_alloc_in: LayoutList = [ + fields.exec_fn, + fields.regs_l, + fields.imm, + fields.csr, + fields.pc, ] - self.reg_alloc_out = self.renaming_in = [ - ("exec_fn", common.exec_fn), - ("regs_l", common.regs_l), - ("regs_p", [("rp_dst", gen_params.phys_regs_bits)]), - ("imm", gen_params.isa.xlen), - ("csr", gen_params.isa.csr_alen), - ("pc", gen_params.isa.xlen), + + self.reg_alloc_out: LayoutList = [ + fields.exec_fn, + fields.regs_l, + self.regs_p_alloc_out, + fields.imm, + fields.csr, + fields.pc, ] - self.renaming_out = self.rob_allocate_in = [ - ("exec_fn", common.exec_fn), - ( - "regs_l", - [ - ("rl_dst", gen_params.isa.reg_cnt_log), - ("rl_dst_v", 1), - ], - ), - ("regs_p", common.regs_p), - ("imm", gen_params.isa.xlen), - ("csr", gen_params.isa.csr_alen), - ("pc", gen_params.isa.xlen), + + self.renaming_in = self.reg_alloc_out + + self.renaming_out: LayoutList = [ + fields.exec_fn, + self.regs_l_rob_in, + fields.regs_p, + fields.imm, + fields.csr, + fields.pc, ] - self.rob_allocate_out = self.rs_select_in = [ - ("exec_fn", common.exec_fn), - ("regs_p", common.regs_p), - ("rob_id", gen_params.rob_entries_bits), - ("imm", gen_params.isa.xlen), - ("csr", gen_params.isa.csr_alen), - ("pc", gen_params.isa.xlen), + + self.rob_allocate_in = self.renaming_out + + self.rob_allocate_out: LayoutList = [ + fields.exec_fn, + fields.regs_p, + fields.rob_id, + fields.imm, + fields.csr, + fields.pc, ] - self.rs_select_out = self.rs_insert_in = [ - ("exec_fn", common.exec_fn), - ("regs_p", common.regs_p), - ("rob_id", gen_params.rob_entries_bits), - ("rs_selected", gen_params.rs_number_bits), - ("rs_entry_id", gen_params.max_rs_entries_bits), - ("imm", gen_params.isa.xlen), - ("csr", gen_params.isa.csr_alen), - ("pc", gen_params.isa.xlen), + + self.rs_select_in = self.rob_allocate_out + + self.rs_select_out: LayoutList = [ + fields.exec_fn, + fields.regs_p, + fields.rob_id, + self.rs_selected, + self.rs_entry_id, + fields.imm, + fields.csr, + fields.pc, ] - self.free_rf_layout = [("reg_id", gen_params.phys_regs_bits)] + + self.rs_insert_in = self.rs_select_out + + self.free_rf_layout: LayoutList = [fields.reg_id] class RFLayouts: + """Layouts used in the register file.""" + def __init__(self, gen_params: GenParams): - self.rf_read_in = self.rf_free = [("reg_id", gen_params.phys_regs_bits)] - self.rf_read_out = [("reg_val", gen_params.isa.xlen), ("valid", 1)] - self.rf_write = [("reg_id", gen_params.phys_regs_bits), ("reg_val", gen_params.isa.xlen)] + fields = gen_params.get(CommonLayoutFields) + + self.valid: LayoutListField = ("valid", 1) + """Physical register was assigned a value.""" + + self.rf_read_in: LayoutList = [fields.reg_id] + self.rf_free: LayoutList = [fields.reg_id] + self.rf_read_out: LayoutList = [fields.reg_val, self.valid] + self.rf_write: LayoutList = [fields.reg_id, fields.reg_val] class RATLayouts: + """Layouts used in the register alias tables.""" + def __init__(self, gen_params: GenParams): - self.rat_rename_in = [ - ("rl_s1", gen_params.isa.reg_cnt_log), - ("rl_s2", gen_params.isa.reg_cnt_log), - ("rl_dst", gen_params.isa.reg_cnt_log), - ("rp_dst", gen_params.phys_regs_bits), + fields = gen_params.get(CommonLayoutFields) + + self.old_rp_dst: LayoutListField = ("old_rp_dst", gen_params.phys_regs_bits) + """Physical register previously associated with the given logical register in RRAT.""" + + self.rat_rename_in: LayoutList = [ + fields.rl_s1, + fields.rl_s2, + fields.rl_dst, + fields.rp_dst, ] - self.rat_rename_out = [("rp_s1", gen_params.phys_regs_bits), ("rp_s2", gen_params.phys_regs_bits)] - self.rat_commit_in = [("rl_dst", gen_params.isa.reg_cnt_log), ("rp_dst", gen_params.phys_regs_bits)] - self.rat_commit_out = [("old_rp_dst", gen_params.phys_regs_bits)] + self.rat_rename_out: LayoutList = [fields.rp_s1, fields.rp_s2] + + self.rat_commit_in: LayoutList = [fields.rl_dst, fields.rp_dst] + + self.rat_commit_out: LayoutList = [self.old_rp_dst] class ROBLayouts: + """Layouts used in the reorder buffer.""" + def __init__(self, gen_params: GenParams): - self.data_layout = [ - ("rl_dst", gen_params.isa.reg_cnt_log), - ("rp_dst", gen_params.phys_regs_bits), - ] + fields = gen_params.get(CommonLayoutFields) - self.id_layout = [ - ("rob_id", gen_params.rob_entries_bits), + self.data_layout: LayoutList = [ + fields.rl_dst, + fields.rp_dst, ] - self.internal_layout = [ - ("rob_data", self.data_layout), - ("done", 1), - ("exception", 1), + self.rob_data: LayoutListField = ("rob_data", self.data_layout) + """Data stored in a reorder buffer entry.""" + + self.done: LayoutListField = ("done", 1) + """Instruction has executed, but is not committed yet.""" + + self.start: LayoutListField = ("start", gen_params.rob_entries_bits) + """Index of the first (the earliest) entry in the reorder buffer.""" + + self.end: LayoutListField = ("end", gen_params.rob_entries_bits) + """Index of the entry following the last (the latest) entry in the reorder buffer.""" + + self.id_layout: LayoutList = [fields.rob_id] + + self.internal_layout: LayoutList = [ + self.rob_data, + self.done, + fields.exception, ] - self.mark_done_layout = [ - ("rob_id", gen_params.rob_entries_bits), - ("exception", 1), + self.mark_done_layout: LayoutList = [ + fields.rob_id, + fields.exception, ] - self.peek_layout = self.retire_layout = [ - ("rob_data", self.data_layout), - ("rob_id", gen_params.rob_entries_bits), - ("exception", 1), + self.peek_layout: LayoutList = [ + self.rob_data, + fields.rob_id, + fields.exception, ] - self.get_indices = [("start", gen_params.rob_entries_bits), ("end", gen_params.rob_entries_bits)] + self.retire_layout: LayoutList = self.peek_layout + self.get_indices: LayoutList = [self.start, self.end] -class RSInterfaceLayouts: - def __init__(self, gen_params: GenParams, *, rs_entries_bits: int): - common = gen_params.get(CommonLayouts) - self.data_layout = [ - ("rp_s1", gen_params.phys_regs_bits), - ("rp_s2", gen_params.phys_regs_bits), + +class RSLayoutFields: + """Layout fields used in the reservation station.""" + + def __init__(self, gen_params: GenParams, *, rs_entries_bits: int, data_layout: LayoutList): + self.rs_data: LayoutListField = ("rs_data", data_layout) + """Data about an instuction stored in a reservation station (RS).""" + + self.rs_entry_id: LayoutListField = ("rs_entry_id", rs_entries_bits) + """Index in a reservation station (RS).""" + + +class RSFullDataLayout: + """Full data layout for functional blocks. Blocks can use a subset.""" + + def __init__(self, gen_params: GenParams): + fields = gen_params.get(CommonLayoutFields) + + self.data_layout: LayoutList = [ + fields.rp_s1, + fields.rp_s2, ("rp_s1_reg", gen_params.phys_regs_bits), ("rp_s2_reg", gen_params.phys_regs_bits), - ("rp_dst", gen_params.phys_regs_bits), - ("rob_id", gen_params.rob_entries_bits), - ("exec_fn", common.exec_fn), - ("s1_val", gen_params.isa.xlen), - ("s2_val", gen_params.isa.xlen), - ("imm", gen_params.isa.xlen), - ("csr", gen_params.isa.csr_alen), - ("pc", gen_params.isa.xlen), + fields.rp_dst, + fields.rob_id, + fields.exec_fn, + fields.s1_val, + fields.s2_val, + fields.imm, + fields.csr, + fields.pc, ] - self.select_out = [("rs_entry_id", rs_entries_bits)] - self.insert_in = [("rs_data", self.data_layout), ("rs_entry_id", rs_entries_bits)] +class RSInterfaceLayouts: + """Layouts used in functional blocks.""" + + def __init__(self, gen_params: GenParams, *, rs_entries_bits: int, data_layout: LayoutList): + fields = gen_params.get(CommonLayoutFields) + rs_fields = gen_params.get(RSLayoutFields, rs_entries_bits=rs_entries_bits, data_layout=data_layout) + + self.data_layout: LayoutList = data_layout - self.update_in = [("tag", gen_params.phys_regs_bits), ("value", gen_params.isa.xlen)] + self.select_out: LayoutList = [rs_fields.rs_entry_id] + + self.insert_in: LayoutList = [rs_fields.rs_data, rs_fields.rs_entry_id] + + self.update_in: LayoutList = [fields.reg_id, fields.reg_val] class RetirementLayouts: + """Layouts used in the retirement module.""" + def __init__(self, gen_params: GenParams): - self.precommit = [ - ("rob_id", gen_params.rob_entries_bits), - ] + fields = gen_params.get(CommonLayoutFields) + + self.precommit: LayoutList = [fields.rob_id] class RSLayouts: + """Layouts used in the reservation station.""" + def __init__(self, gen_params: GenParams, *, rs_entries_bits: int): - rs_interface = gen_params.get(RSInterfaceLayouts, rs_entries_bits=rs_entries_bits) + data = gen_params.get(RSFullDataLayout) - self.data_layout = layout_subset( - rs_interface.data_layout, + self.ready_list: LayoutListField = ("ready_list", 2**rs_entries_bits) + """Bitmask of reservation station entries containing instructions which are ready to run.""" + + data_layout = layout_subset( + data.data_layout, fields={ "rp_s1", "rp_s2", @@ -197,16 +352,13 @@ def __init__(self, gen_params: GenParams, *, rs_entries_bits: int): }, ) - self.insert_in = [("rs_data", self.data_layout), ("rs_entry_id", rs_entries_bits)] - - self.select_out = rs_interface.select_out + self.rs = gen_params.get(RSInterfaceLayouts, rs_entries_bits=rs_entries_bits, data_layout=data_layout) + rs_fields = gen_params.get(RSLayoutFields, rs_entries_bits=rs_entries_bits, data_layout=data_layout) - self.update_in = rs_interface.update_in - - self.take_in = [("rs_entry_id", rs_entries_bits)] + self.take_in: LayoutList = [rs_fields.rs_entry_id] self.take_out = layout_subset( - rs_interface.data_layout, + data.data_layout, fields={ "s1_val", "s2_val", @@ -218,113 +370,137 @@ def __init__(self, gen_params: GenParams, *, rs_entries_bits: int): }, ) - self.get_ready_list_out = [("ready_list", 2**rs_entries_bits)] + self.get_ready_list_out: LayoutList = [self.ready_list] class ICacheLayouts: + """Layouts used in the instruction cache.""" + def __init__(self, gen_params: GenParams): - self.issue_req = [ - ("addr", gen_params.isa.xlen), - ] + fields = gen_params.get(CommonLayoutFields) + + self.error: LayoutListField = ("last", 1) + """This is the last cache refill result.""" + + self.issue_req: LayoutList = [fields.addr] - self.accept_res = [ - ("instr", gen_params.isa.ilen), - ("error", 1), + self.accept_res: LayoutList = [ + fields.instr, + fields.error, ] - self.start_refill = [ - ("addr", gen_params.isa.xlen), + self.start_refill: LayoutList = [ + fields.addr, ] - self.accept_refill = [ - ("addr", gen_params.isa.xlen), - ("data", gen_params.isa.xlen), - ("error", 1), - ("last", 1), + self.accept_refill: LayoutList = [ + fields.addr, + fields.data, + fields.error, + self.error, ] class FetchLayouts: + """Layouts used in the fetcher.""" + def __init__(self, gen_params: GenParams): - self.raw_instr = [ - ("data", gen_params.isa.ilen), - ("pc", gen_params.isa.xlen), - ("access_fault", 1), - ("rvc", 1), + fields = gen_params.get(CommonLayoutFields) + + self.access_fault: LayoutListField = ("access_fault", 1) + """Instruction fetch failed.""" + + self.rvc: LayoutListField = ("rvc", 1) + """Instruction is a compressed (two-byte) one.""" + + self.raw_instr: LayoutList = [ + fields.instr, + fields.pc, + self.access_fault, + self.rvc, ] - self.branch_verify = [ + self.branch_verify: LayoutList = [ ("from_pc", gen_params.isa.xlen), ("next_pc", gen_params.isa.xlen), ] class DecodeLayouts: + """Layouts used in the decoder.""" + def __init__(self, gen_params: GenParams): - common = gen_params.get(CommonLayouts) - self.decoded_instr = [ - ("exec_fn", common.exec_fn), - ("regs_l", common.regs_l), - ("imm", gen_params.isa.xlen), - ("csr", gen_params.isa.csr_alen), - ("pc", gen_params.isa.xlen), + fields = gen_params.get(CommonLayoutFields) + + self.decoded_instr: LayoutList = [ + fields.exec_fn, + fields.regs_l, + fields.imm, + fields.csr, + fields.pc, ] class FuncUnitLayouts: + """Layouts used in functional units.""" + def __init__(self, gen_params: GenParams): - common = gen_params.get(CommonLayouts) - - self.issue = [ - ("s1_val", gen_params.isa.xlen), - ("s2_val", gen_params.isa.xlen), - ("rp_dst", gen_params.phys_regs_bits), - ("rob_id", gen_params.rob_entries_bits), - ("exec_fn", common.exec_fn), - ("imm", gen_params.isa.xlen), - ("pc", gen_params.isa.xlen), + fields = gen_params.get(CommonLayoutFields) + + self.result: LayoutListField = ("result", gen_params.isa.xlen) + """The result value produced in a functional unit.""" + + self.issue: LayoutList = [ + fields.s1_val, + fields.s2_val, + fields.rp_dst, + fields.rob_id, + fields.exec_fn, + fields.imm, + fields.pc, ] - self.accept = [ - ("rob_id", gen_params.rob_entries_bits), - ("result", gen_params.isa.xlen), - ("rp_dst", gen_params.phys_regs_bits), - ("exception", 1), + self.accept: LayoutList = [ + fields.rob_id, + self.result, + fields.rp_dst, + fields.exception, ] class UnsignedMulUnitLayouts: def __init__(self, gen_params: GenParams): - self.issue = [ + self.issue: LayoutList = [ ("i1", gen_params.isa.xlen), ("i2", gen_params.isa.xlen), ] - self.accept = [ + self.accept: LayoutList = [ ("o", 2 * gen_params.isa.xlen), ] class DivUnitLayouts: def __init__(self, gen: GenParams): - self.issue = [ + self.issue: LayoutList = [ ("dividend", gen.isa.xlen), ("divisor", gen.isa.xlen), ] - self.accept = [ + self.accept: LayoutList = [ ("quotient", gen.isa.xlen), ("remainder", gen.isa.xlen), ] class LSULayouts: + """Layouts used in the load-store unit.""" + def __init__(self, gen_params: GenParams): - self.rs_entries_bits = 0 + data = gen_params.get(RSFullDataLayout) - rs_interface = gen_params.get(RSInterfaceLayouts, rs_entries_bits=self.rs_entries_bits) - self.rs_data_layout = layout_subset( - rs_interface.data_layout, + data_layout = layout_subset( + data.data_layout, fields={ "rp_s1", "rp_s2", @@ -337,11 +513,9 @@ def __init__(self, gen_params: GenParams): }, ) - self.rs_insert_in = [("rs_data", self.rs_data_layout), ("rs_entry_id", self.rs_entries_bits)] - - self.rs_select_out = rs_interface.select_out + self.rs_entries_bits = 0 - self.rs_update_in = rs_interface.update_in + self.rs = gen_params.get(RSInterfaceLayouts, rs_entries_bits=self.rs_entries_bits, data_layout=data_layout) retirement = gen_params.get(RetirementLayouts) @@ -349,23 +523,27 @@ def __init__(self, gen_params: GenParams): class CSRLayouts: + """Layouts used in the control and status registers.""" + def __init__(self, gen_params: GenParams): + data = gen_params.get(RSFullDataLayout) + fields = gen_params.get(CommonLayoutFields) + self.rs_entries_bits = 0 - self.read = [ - ("data", gen_params.isa.xlen), + self.read: LayoutList = [ + fields.data, ("read", 1), ("written", 1), ] - self.write = [("data", gen_params.isa.xlen)] + self.write: LayoutList = [fields.data] - self._fu_read = [("data", gen_params.isa.xlen)] - self._fu_write = [("data", gen_params.isa.xlen)] + self._fu_read: LayoutList = [fields.data] + self._fu_write: LayoutList = [fields.data] - rs_interface = gen_params.get(RSInterfaceLayouts, rs_entries_bits=self.rs_entries_bits) - self.rs_data_layout = layout_subset( - rs_interface.data_layout, + data_layout = layout_subset( + data.data_layout, fields={ "rp_s1", "rp_s1_reg", @@ -379,11 +557,7 @@ def __init__(self, gen_params: GenParams): }, ) - self.rs_insert_in = [("rs_data", self.rs_data_layout), ("rs_entry_id", self.rs_entries_bits)] - - self.rs_select_out = rs_interface.select_out - - self.rs_update_in = rs_interface.update_in + self.rs = gen_params.get(RSInterfaceLayouts, rs_entries_bits=self.rs_entries_bits, data_layout=data_layout) retirement = gen_params.get(RetirementLayouts) @@ -391,8 +565,17 @@ def __init__(self, gen_params: GenParams): class ExceptionRegisterLayouts: + """Layouts used in the exception register.""" + def __init__(self, gen_params: GenParams): - self.get = self.report = [ - ("cause", ExceptionCause), - ("rob_id", gen_params.rob_entries_bits), + fields = gen_params.get(CommonLayoutFields) + + self.cause: LayoutListField = ("cause", ExceptionCause) + """Exception cause.""" + + self.get: LayoutList = [ + self.cause, + fields.rob_id, ] + + self.report = self.get diff --git a/coreblocks/stages/backend.py b/coreblocks/stages/backend.py index 4ba245baf..00bfb7940 100644 --- a/coreblocks/stages/backend.py +++ b/coreblocks/stages/backend.py @@ -20,7 +20,7 @@ class ResultAnnouncement(Elaboratable): """ def __init__( - self, *, gen: GenParams, get_result: Method, rob_mark_done: Method, rs_write_val: Method, rf_write_val: Method + self, *, gen: GenParams, get_result: Method, rob_mark_done: Method, rs_update: Method, rf_write: Method ): """ Parameters @@ -34,20 +34,17 @@ def __init__( from different FUs are already serialized. rob_mark_done : Method Method which is invoked to mark that instruction ended without exception. - It uses layout with one field `rob_id`, - rs_write_val : Method + rs_update : Method Method which is invoked to pass value which is an output of finished instruction to RS, so that RS can save it if there are instructions which wait for it. - It uses layout with two fields `tag` and `value`. - rf_write_val : Method + rf_write : Method Method which is invoked to save value which is an output of finished instruction to RF. - It uses layout with two fields `reg_id` and `reg_val`. """ self.m_get_result = get_result self.m_rob_mark_done = rob_mark_done - self.m_rs_write_val = rs_write_val - self.m_rf_write_val = rf_write_val + self.m_rs_update = rs_update + self.m_rf_write_val = rf_write def debug_signals(self): return [self.m_get_result.debug_signals()] @@ -62,6 +59,6 @@ def elaborate(self, platform): with m.If(result.exception == 0): self.m_rf_write_val(m, reg_id=result.rp_dst, reg_val=result.result) with m.If(result.rp_dst != 0): - self.m_rs_write_val(m, tag=result.rp_dst, value=result.result) + self.m_rs_update(m, reg_id=result.rp_dst, reg_val=result.result) return m diff --git a/coreblocks/stages/rs_func_block.py b/coreblocks/stages/rs_func_block.py index e2ff47983..9b3a45c4b 100644 --- a/coreblocks/stages/rs_func_block.py +++ b/coreblocks/stages/rs_func_block.py @@ -47,9 +47,9 @@ def __init__(self, gen_params: GenParams, func_units: Iterable[tuple[FuncUnit, s self.fu_layouts = gen_params.get(FuncUnitLayouts) self.func_units = list(func_units) - self.insert = Method(i=self.rs_layouts.insert_in) - self.select = Method(o=self.rs_layouts.select_out) - self.update = Method(i=self.rs_layouts.update_in) + self.insert = Method(i=self.rs_layouts.rs.insert_in) + self.select = Method(o=self.rs_layouts.rs.select_out) + self.update = Method(i=self.rs_layouts.rs.update_in) self.get_result = Method(o=self.fu_layouts.accept) def elaborate(self, platform): diff --git a/coreblocks/structs_common/csr.py b/coreblocks/structs_common/csr.py index 3dc4763d6..f1a6d89f1 100644 --- a/coreblocks/structs_common/csr.py +++ b/coreblocks/structs_common/csr.py @@ -194,9 +194,9 @@ def __init__(self, gen_params: GenParams): # Standard RS interface self.csr_layouts = gen_params.get(CSRLayouts) self.fu_layouts = gen_params.get(FuncUnitLayouts) - self.select = Method(o=self.csr_layouts.rs_select_out) - self.insert = Method(i=self.csr_layouts.rs_insert_in) - self.update = Method(i=self.csr_layouts.rs_update_in) + self.select = Method(o=self.csr_layouts.rs.select_out) + self.insert = Method(i=self.csr_layouts.rs.insert_in) + self.update = Method(i=self.csr_layouts.rs.update_in) self.get_result = Method(o=self.fu_layouts.accept) self.precommit = Method(i=self.csr_layouts.precommit) @@ -223,7 +223,7 @@ def elaborate(self, platform): current_result = Signal(self.gen_params.isa.xlen) - instr = Record(self.csr_layouts.rs_data_layout + [("valid", 1)]) + instr = Record(self.csr_layouts.rs.data_layout + [("valid", 1)]) m.d.comb += ready_to_process.eq(rob_sfx_empty & instr.valid & (instr.rp_s1 == 0)) @@ -310,9 +310,9 @@ def _(rs_entry_id, rs_data): m.d.sync += instr.valid.eq(1) @def_method(m, self.update) - def _(tag, value): - with m.If(tag == instr.rp_s1): - m.d.sync += instr.s1_val.eq(value) + def _(reg_id, reg_val): + with m.If(reg_id == instr.rp_s1): + m.d.sync += instr.s1_val.eq(reg_val) m.d.sync += instr.rp_s1.eq(0) @def_method(m, self.get_result, done) diff --git a/coreblocks/structs_common/rs.py b/coreblocks/structs_common/rs.py index eb045c3a4..255f48a63 100644 --- a/coreblocks/structs_common/rs.py +++ b/coreblocks/structs_common/rs.py @@ -19,15 +19,15 @@ def __init__( self.rs_entries_bits = (rs_entries - 1).bit_length() self.layouts = gen_params.get(RSLayouts, rs_entries_bits=self.rs_entries_bits) self.internal_layout = [ - ("rs_data", self.layouts.data_layout), + ("rs_data", self.layouts.rs.data_layout), ("rec_full", 1), ("rec_ready", 1), ("rec_reserved", 1), ] - self.insert = Method(i=self.layouts.insert_in) - self.select = Method(o=self.layouts.select_out) - self.update = Method(i=self.layouts.update_in) + self.insert = Method(i=self.layouts.rs.insert_in) + self.select = Method(o=self.layouts.rs.select_out) + self.update = Method(i=self.layouts.rs.update_in) self.take = Method(i=self.layouts.take_in, o=self.layouts.take_out) self.ready_for = [list(op_list) for op_list in ready_for] @@ -70,16 +70,16 @@ def _(rs_entry_id: Value, rs_data: Value) -> None: m.d.sync += self.data[rs_entry_id].rec_reserved.eq(1) @def_method(m, self.update) - def _(tag: Value, value: Value) -> None: + def _(reg_id: Value, reg_val: Value) -> None: for record in self.data: with m.If(record.rec_full.bool()): - with m.If(record.rs_data.rp_s1 == tag): + with m.If(record.rs_data.rp_s1 == reg_id): m.d.sync += record.rs_data.rp_s1.eq(0) - m.d.sync += record.rs_data.s1_val.eq(value) + m.d.sync += record.rs_data.s1_val.eq(reg_val) - with m.If(record.rs_data.rp_s2 == tag): + with m.If(record.rs_data.rp_s2 == reg_id): m.d.sync += record.rs_data.rp_s2.eq(0) - m.d.sync += record.rs_data.s2_val.eq(value) + m.d.sync += record.rs_data.s2_val.eq(reg_val) @def_method(m, self.take, ready=take_possible) def _(rs_entry_id: Value) -> RecordDict: diff --git a/test/frontend/test_decode.py b/test/frontend/test_decode.py index 1128df9e8..f728152f5 100644 --- a/test/frontend/test_decode.py +++ b/test/frontend/test_decode.py @@ -40,7 +40,7 @@ def setUp(self) -> None: def decode_test_proc(self): # testing an OP_IMM instruction (test copied from test_decoder.py) - yield from self.test_module.io_in.call(data=0x02A28213) + yield from self.test_module.io_in.call(instr=0x02A28213) decoded = yield from self.test_module.io_out.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.ARITHMETIC) @@ -52,7 +52,7 @@ def decode_test_proc(self): self.assertEqual(decoded["imm"], 42) # testing an OP instruction (test copied from test_decoder.py) - yield from self.test_module.io_in.call(data=0x003100B3) + yield from self.test_module.io_in.call(instr=0x003100B3) decoded = yield from self.test_module.io_out.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.ARITHMETIC) @@ -63,7 +63,7 @@ def decode_test_proc(self): self.assertEqual(decoded["regs_l"]["rl_s2"], 3) # testing an illegal - yield from self.test_module.io_in.call(data=0x0) + yield from self.test_module.io_in.call(instr=0x0) decoded = yield from self.test_module.io_out.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.EXCEPTION) @@ -73,7 +73,7 @@ def decode_test_proc(self): self.assertEqual(decoded["regs_l"]["rl_s1"], 0) self.assertEqual(decoded["regs_l"]["rl_s2"], 0) - yield from self.test_module.io_in.call(data=0x0, access_fault=1) + yield from self.test_module.io_in.call(instr=0x0, access_fault=1) decoded = yield from self.test_module.io_out.call() self.assertEqual(decoded["exec_fn"]["op_type"], OpType.EXCEPTION) diff --git a/test/frontend/test_fetch.py b/test/frontend/test_fetch.py index 3817c0312..02c9726d5 100644 --- a/test/frontend/test_fetch.py +++ b/test/frontend/test_fetch.py @@ -105,7 +105,7 @@ def cache_process(): self.instr_queue.append( { - "data": data, + "instr": data, "pc": addr, "is_branch": is_branch, "next_pc": next_pc, @@ -134,7 +134,7 @@ def fetch_out_check(self): v = yield from self.m.io_out.call() self.assertEqual(v["pc"], instr["pc"]) - self.assertEqual(v["data"], instr["data"]) + self.assertEqual(v["instr"], instr["instr"]) def test(self): issue_req_mock, accept_res_mock, cache_process = self.cache_processes() diff --git a/test/lsu/test_dummylsu.py b/test/lsu/test_dummylsu.py index bfa29d532..4ac3c059a 100644 --- a/test/lsu/test_dummylsu.py +++ b/test/lsu/test_dummylsu.py @@ -23,7 +23,7 @@ def generate_register(max_reg_val: int, phys_regs_bits: int) -> tuple[int, int, rp = random.randint(1, 2**phys_regs_bits - 1) val = 0 real_val = random.randint(0, max_reg_val // 4) * 4 - ann_data = {"tag": rp, "value": real_val} + ann_data = {"reg_id": rp, "reg_val": real_val} else: rp = 0 val = random.randint(0, max_reg_val // 4) * 4 @@ -338,7 +338,7 @@ def generate_instr(self, max_reg_val, max_imm_val): rp_s2, s2_val, ann_data2, data = generate_register(0xFFFFFFFF, self.gp.phys_regs_bits) if rp_s1 == rp_s2 and ann_data1 is not None and ann_data2 is not None: ann_data2 = None - data = ann_data1["value"] + data = ann_data1["reg_val"] # decide in which order we would get announcments if random.randint(0, 1): self.announce_queue.append((ann_data1, ann_data2)) diff --git a/test/scheduler/test_rs_selection.py b/test/scheduler/test_rs_selection.py index 65c5dac2b..df6d7641f 100644 --- a/test/scheduler/test_rs_selection.py +++ b/test/scheduler/test_rs_selection.py @@ -31,8 +31,8 @@ def elaborate(self, platform): # mocked input and output m.submodules.instr_in = self.instr_in = TestbenchIO(AdapterTrans(instr_fifo.write)) m.submodules.instr_out = self.instr_out = TestbenchIO(AdapterTrans(out_fifo.read)) - m.submodules.rs1_alloc = self.rs1_alloc = TestbenchIO(Adapter(o=rs_layouts.select_out)) - m.submodules.rs2_alloc = self.rs2_alloc = TestbenchIO(Adapter(o=rs_layouts.select_out)) + m.submodules.rs1_alloc = self.rs1_alloc = TestbenchIO(Adapter(o=rs_layouts.rs.select_out)) + m.submodules.rs2_alloc = self.rs2_alloc = TestbenchIO(Adapter(o=rs_layouts.rs.select_out)) # rs selector m.submodules.selector = self.selector = RSSelection( diff --git a/test/scheduler/test_scheduler.py b/test/scheduler/test_scheduler.py index 72c9c5d3d..939cef7bd 100644 --- a/test/scheduler/test_scheduler.py +++ b/test/scheduler/test_scheduler.py @@ -60,8 +60,8 @@ def elaborate(self, platform): # mocked RS for i, rs in enumerate(self.rs): - alloc_adapter = Adapter(o=rs_layouts.select_out) - insert_adapter = Adapter(i=rs_layouts.insert_in) + alloc_adapter = Adapter(o=rs_layouts.rs.select_out) + insert_adapter = Adapter(i=rs_layouts.rs.insert_in) select_test = TestbenchIO(alloc_adapter) insert_test = TestbenchIO(insert_adapter) diff --git a/test/stages/test_backend.py b/test/stages/test_backend.py index b230ca61b..047d70517 100644 --- a/test/stages/test_backend.py +++ b/test/stages/test_backend.py @@ -23,7 +23,7 @@ def elaborate(self, platform): self.lay_result = self.gen.get(FuncUnitLayouts).accept self.lay_rob_mark_done = self.gen.get(ROBLayouts).mark_done_layout - self.lay_rs_write = self.gen.get(RSLayouts, rs_entries_bits=self.gen.max_rs_entries_bits).update_in + self.lay_rs_write = self.gen.get(RSLayouts, rs_entries_bits=self.gen.max_rs_entries_bits).rs.update_in self.lay_rf_write = self.gen.get(RFLayouts).rf_write # Initialize for each FU an FIFO which will be a stub for that FU @@ -59,8 +59,8 @@ def elaborate(self, platform): gen=self.gen, get_result=serialized_results_fifo.read, rob_mark_done=self.rob_mark_done_tbio.adapter.iface, - rs_write_val=self.rs_announce_val_tbio.adapter.iface, - rf_write_val=self.rf_announce_val_tbio.adapter.iface, + rs_update=self.rs_announce_val_tbio.adapter.iface, + rf_write=self.rf_announce_val_tbio.adapter.iface, ) return m diff --git a/test/structs_common/test_csr.py b/test/structs_common/test_csr.py index d1b1e64e4..d7fac6dc9 100644 --- a/test/structs_common/test_csr.py +++ b/test/structs_common/test_csr.py @@ -127,7 +127,7 @@ def process_test(self): yield from self.random_wait() if op["exp"]["rs1"]["rp_s1"]: - yield from self.dut.update.call(tag=op["exp"]["rs1"]["rp_s1"], value=op["exp"]["rs1"]["value"]) + yield from self.dut.update.call(reg_id=op["exp"]["rs1"]["rp_s1"], reg_val=op["exp"]["rs1"]["value"]) yield from self.random_wait() yield from self.dut.precommit.call() diff --git a/test/structs_common/test_rs.py b/test/structs_common/test_rs.py index 9cb678946..3de32528b 100644 --- a/test/structs_common/test_rs.py +++ b/test/structs_common/test_rs.py @@ -206,7 +206,7 @@ def simulation_process(self): # Update second entry first SP, instruction should be not ready value_sp1 = 1010 self.assertEqual((yield self.m.rs.data[1].rec_ready), 0) - yield from self.m.io_update.call(tag=2, value=value_sp1) + yield from self.m.io_update.call(reg_id=2, reg_val=value_sp1) yield Settle() self.assertEqual((yield self.m.rs.data[1].rs_data.rp_s1), 0) self.assertEqual((yield self.m.rs.data[1].rs_data.s1_val), value_sp1) @@ -214,18 +214,18 @@ def simulation_process(self): # Update second entry second SP, instruction should be ready value_sp2 = 2020 - yield from self.m.io_update.call(tag=3, value=value_sp2) + yield from self.m.io_update.call(reg_id=3, reg_val=value_sp2) yield Settle() self.assertEqual((yield self.m.rs.data[1].rs_data.rp_s2), 0) self.assertEqual((yield self.m.rs.data[1].rs_data.s2_val), value_sp2) self.assertEqual((yield self.m.rs.data[1].rec_ready), 1) - # Insert new insturction to entries 0 and 1, check if update of multiple tags works - tag = 4 + # Insert new instruction to entries 0 and 1, check if update of multiple registers works + reg_id = 4 value_spx = 3030 data = { - "rp_s1": tag, - "rp_s2": tag, + "rp_s1": reg_id, + "rp_s2": reg_id, "rp_dst": 1, "rob_id": 12, "exec_fn": { @@ -243,7 +243,7 @@ def simulation_process(self): yield Settle() self.assertEqual((yield self.m.rs.data[index].rec_ready), 0) - yield from self.m.io_update.call(tag=tag, value=value_spx) + yield from self.m.io_update.call(reg_id=reg_id, reg_val=value_spx) yield Settle() for index in range(2): self.assertEqual((yield self.m.rs.data[index].rs_data.rp_s1), 0) @@ -302,9 +302,9 @@ def simulation_process(self): self.assertEqual((yield self.m.rs.take.ready), 0) # Update second instuction and take it - tag = 2 + reg_id = 2 value_spx = 1 - yield from self.m.io_update.call(tag=tag, value=value_spx) + yield from self.m.io_update.call(reg_id=reg_id, reg_val=value_spx) yield Settle() self.assertEqual((yield self.m.rs.take.ready), 1) data = yield from self.m.io_take.call(rs_entry_id=1) @@ -314,11 +314,11 @@ def simulation_process(self): self.assertEqual((yield self.m.rs.take.ready), 0) # Insert two new ready instructions and take them - tag = 0 + reg_id = 0 value_spx = 3030 entry_data = { - "rp_s1": tag, - "rp_s2": tag, + "rp_s1": reg_id, + "rp_s2": reg_id, "rp_dst": 1, "rob_id": 12, "exec_fn": { diff --git a/test/test_core.py b/test/test_core.py index a44d92c12..ed5efb337 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -111,7 +111,7 @@ def get_phys_reg_val(self, reg_id): return (yield self.m.core.RF.entries[reg_id].reg_val) def push_instr(self, opcode): - yield from self.m.io_in.call(data=opcode) + yield from self.m.io_in.call(instr=opcode) def compare_core_states(self, sw_core): for i in range(self.gp.isa.reg_cnt): diff --git a/transactron/utils/_typing.py b/transactron/utils/_typing.py index d12c2611b..f82aba30c 100644 --- a/transactron/utils/_typing.py +++ b/transactron/utils/_typing.py @@ -30,7 +30,8 @@ # Internal Coreblocks types SignalBundle: TypeAlias = Signal | Record | View | Iterable["SignalBundle"] | Mapping[str, "SignalBundle"] -LayoutList: TypeAlias = list[tuple[str, "ShapeLike | LayoutList"]] +LayoutListField: TypeAlias = tuple[str, "ShapeLike | LayoutList"] +LayoutList: TypeAlias = list[LayoutListField] RecordIntDict: TypeAlias = Mapping[str, Union[int, "RecordIntDict"]] RecordIntDictRet: TypeAlias = Mapping[str, Any] # full typing hard to work with diff --git a/transactron/utils/utils.py b/transactron/utils/utils.py index 2f7b78cb6..13491cd7b 100644 --- a/transactron/utils/utils.py +++ b/transactron/utils/utils.py @@ -13,6 +13,7 @@ "assign", "OneHotSwitchDynamic", "OneHotSwitch", + "make_hashable", "flatten_signals", "align_to_power_of_two", "align_down_to_power_of_two", @@ -336,6 +337,15 @@ def layout_subset(layout: LayoutList, *, fields: set[str]) -> LayoutList: return [item for item in layout if item[0] in fields] +def make_hashable(val): + if isinstance(val, Mapping): + return frozenset(((k, make_hashable(v)) for k, v in val.items())) + elif isinstance(val, Iterable): + return (make_hashable(v) for v in val) + else: + return val + + def flatten_signals(signals: SignalBundle) -> Iterable[Signal]: """ Flattens input data, which can be either a signal, a record, a list (or a dict) of SignalBundle items.