diff --git a/coreblocks/cache/icache.py b/coreblocks/cache/icache.py index 605e22e88..08cd51784 100644 --- a/coreblocks/cache/icache.py +++ b/coreblocks/cache/icache.py @@ -11,6 +11,7 @@ from coreblocks.interface.layouts import ICacheLayouts from transactron.utils import assign, OneHotSwitchDynamic from transactron.lib import * +from transactron.lib import logging from coreblocks.peripherals.bus_adapter import BusMasterInterface from coreblocks.cache.iface import CacheInterface, CacheRefillerInterface @@ -21,19 +22,7 @@ "ICacheBypass", ] - -def extract_instr_from_word(m: TModule, params: ICacheParameters, word: Signal, addr: Value): - instr_out = Signal(params.instr_width) - if len(word) == 32: - m.d.comb += instr_out.eq(word) - elif len(word) == 64: - with m.If(addr[2] == 0): - m.d.comb += instr_out.eq(word[:32]) # Take lower 4 bytes - with m.Else(): - m.d.comb += instr_out.eq(word[32:]) # Take upper 4 bytes - else: - raise RuntimeError("Word size different than 32 and 64 is not supported") - return instr_out +log = logging.HardwareLogger("frontend.icache") class ICacheBypass(Elaboratable, CacheInterface): @@ -45,6 +34,9 @@ def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, bus_master: self.accept_res = Method(o=layouts.accept_res) self.flush = Method() + if params.words_in_fetch_block != 1: + raise ValueError("ICacheBypass only supports fetch block size equal to the word size.") + def elaborate(self, platform): m = TModule() @@ -63,7 +55,7 @@ def _(addr: Value) -> None: def _(): res = self.bus_master.get_read_response(m) return { - "instr": extract_instr_from_word(m, self.params, res.data, req_addr), + "fetch_block": res.data, "error": res.err, } @@ -82,10 +74,10 @@ class ICache(Elaboratable, CacheInterface): Refilling a cache line is abstracted away from this module. ICache module needs two methods from the refiller `refiller_start`, which is called whenever we need to refill a cache line. 
- `refiller_accept` should be ready to be called whenever the refiller has another word ready - to be written to cache. `refiller_accept` should set `last` bit when either an error occurs - or the transfer is over. After issuing `last` bit, `refiller_accept` shouldn't be ready until - the next transfer is started. + `refiller_accept` should be ready to be called whenever the refiller has another fetch block + ready to be written to cache. `refiller_accept` should set `last` bit when either an error + occurs or the transfer is over. After issuing `last` bit, `refiller_accept` shouldn't be ready + until the next transfer is started. """ def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, refiller: CacheRefillerInterface) -> None: @@ -150,14 +142,13 @@ def elaborate(self, platform): ] m.submodules.mem = self.mem = ICacheMemory(self.params) - m.submodules.req_fifo = self.req_fifo = FIFO(layout=self.addr_layout, depth=2) - m.submodules.res_fwd = self.res_fwd = Forwarder(layout=self.layouts.accept_res) + m.submodules.req_zipper = req_zipper = ArgumentsToResultsZipper(self.addr_layout, self.layouts.accept_res) # State machine logic needs_refill = Signal() refill_finish = Signal() - refill_finish_last = Signal() refill_error = Signal() + refill_error_saved = Signal() flush_start = Signal() flush_finish = Signal() @@ -166,6 +157,7 @@ def elaborate(self, platform): self.perf_flushes.incr(m, cond=flush_finish) with m.FSM(reset="FLUSH") as fsm: + with m.State("FLUSH"): with m.If(flush_finish): m.next = "LOOKUP" @@ -188,37 +180,44 @@ def elaborate(self, platform): m.d.sync += way_selector.eq(way_selector.rotate_left(1)) # Fast path - read requests - request_valid = self.req_fifo.read.ready - request_addr = Signal(self.addr_layout) + mem_read_addr = Signal(self.addr_layout) + prev_mem_read_addr = Signal(self.addr_layout) + m.d.comb += assign(mem_read_addr, prev_mem_read_addr) - tag_hit = [tag_data.valid & (tag_data.tag == request_addr.tag) for tag_data in 
self.mem.tag_rd_data] - tag_hit_any = reduce(operator.or_, tag_hit) + mem_read_output_valid = Signal() + with Transaction(name="MemRead").body( + m, request=fsm.ongoing("LOOKUP") & (mem_read_output_valid | refill_error_saved) + ): + req_addr = req_zipper.peek_arg(m) - mem_out = Signal(self.params.word_width) - for i in OneHotSwitchDynamic(m, Cat(tag_hit)): - m.d.comb += mem_out.eq(self.mem.data_rd_data[i]) + tag_hit = [tag_data.valid & (tag_data.tag == req_addr.tag) for tag_data in self.mem.tag_rd_data] + tag_hit_any = reduce(operator.or_, tag_hit) - instr_out = extract_instr_from_word(m, self.params, mem_out, Value.cast(request_addr)) + with m.If(tag_hit_any | refill_error_saved): + self.perf_hits.incr(m, cond=tag_hit_any) + mem_out = Signal(self.params.fetch_block_bytes * 8) + for i in OneHotSwitchDynamic(m, Cat(tag_hit)): + m.d.av_comb += mem_out.eq(self.mem.data_rd_data[i]) - refill_error_saved = Signal() - m.d.comb += needs_refill.eq(request_valid & ~tag_hit_any & ~refill_error_saved) + req_zipper.write_results(m, fetch_block=mem_out, error=refill_error_saved) + m.d.sync += refill_error_saved.eq(0) + m.d.sync += mem_read_output_valid.eq(0) + with m.Else(): + self.perf_misses.incr(m) - with Transaction().body(m, request=request_valid & fsm.ongoing("LOOKUP") & (tag_hit_any | refill_error_saved)): - self.perf_errors.incr(m, cond=refill_error_saved) - self.perf_misses.incr(m, cond=refill_finish_last) - self.perf_hits.incr(m, cond=~refill_finish_last) + m.d.comb += needs_refill.eq(1) - self.res_fwd.write(m, instr=instr_out, error=refill_error_saved) - m.d.sync += refill_error_saved.eq(0) + # Align to the beginning of the cache line + aligned_addr = self.serialize_addr(req_addr) & ~((1 << self.params.offset_bits) - 1) + log.debug(m, True, "Refilling line 0x{:x}", aligned_addr) + self.refiller.start_refill(m, addr=aligned_addr) @def_method(m, self.accept_res) def _(): - self.req_fifo.read(m) self.req_latency.stop(m) - return self.res_fwd.read(m) - mem_read_addr = 
Signal(self.addr_layout) - m.d.comb += assign(mem_read_addr, request_addr) + output = req_zipper.read(m) + return output.results @def_method(m, self.issue_req, ready=accepting_requests) def _(addr: Value) -> None: @@ -226,11 +225,11 @@ def _(addr: Value) -> None: self.req_latency.start(m) deserialized = self.deserialize_addr(addr) - # Forward read address only if the method is called m.d.comb += assign(mem_read_addr, deserialized) - m.d.sync += assign(request_addr, deserialized) + m.d.sync += assign(prev_mem_read_addr, deserialized) + req_zipper.write_args(m, deserialized) - self.req_fifo.write(m, deserialized) + m.d.sync += mem_read_output_valid.eq(1) m.d.comb += [ self.mem.tag_rd_index.eq(mem_read_addr.index), @@ -245,34 +244,30 @@ def _(addr: Value) -> None: @def_method(m, self.flush, ready=accepting_requests) def _() -> None: + log.info(m, True, "Flushing the cache...") m.d.sync += flush_index.eq(0) m.d.comb += flush_start.eq(1) m.d.comb += flush_finish.eq(flush_index == self.params.num_of_sets - 1) # Slow path - data refilling - with Transaction().body(m, request=fsm.ongoing("LOOKUP") & needs_refill): - # Align to the beginning of the cache line - aligned_addr = self.serialize_addr(request_addr) & ~((1 << self.params.offset_bits) - 1) - self.refiller.start_refill(m, addr=aligned_addr) - - m.d.sync += refill_finish_last.eq(0) - with Transaction().body(m): ret = self.refiller.accept_refill(m) deserialized = self.deserialize_addr(ret.addr) + self.perf_errors.incr(m, cond=ret.error) + m.d.top_comb += [ self.mem.data_wr_addr.index.eq(deserialized["index"]), self.mem.data_wr_addr.offset.eq(deserialized["offset"]), - self.mem.data_wr_data.eq(ret.data), + self.mem.data_wr_data.eq(ret.fetch_block), ] m.d.comb += self.mem.data_wr_en.eq(1) m.d.comb += refill_finish.eq(ret.last) - m.d.sync += refill_finish_last.eq(1) m.d.comb += refill_error.eq(ret.error) - m.d.sync += refill_error_saved.eq(ret.error) + with m.If(ret.error): + m.d.sync += refill_error_saved.eq(1) with 
m.If(fsm.ongoing("FLUSH")): m.d.comb += [ @@ -285,9 +280,9 @@ def _() -> None: with m.Else(): m.d.comb += [ self.mem.way_wr_en.eq(way_selector), - self.mem.tag_wr_index.eq(request_addr.index), + self.mem.tag_wr_index.eq(mem_read_addr.index), self.mem.tag_wr_data.valid.eq(~refill_error), - self.mem.tag_wr_data.tag.eq(request_addr.tag), + self.mem.tag_wr_data.tag.eq(mem_read_addr.tag), self.mem.tag_wr_en.eq(refill_finish), ] @@ -301,7 +296,7 @@ class ICacheMemory(Elaboratable): Writes are multiplexed using one-hot `way_wr_en` signal. Read data lines from all ways are separately exposed (as an array). - The data memory is addressed using a machine word. + The data memory is addressed using fetch blocks. """ def __init__(self, params: ICacheParameters) -> None: @@ -319,11 +314,13 @@ def __init__(self, params: ICacheParameters) -> None: self.data_addr_layout = make_layout(("index", self.params.index_bits), ("offset", self.params.offset_bits)) + self.fetch_block_bits = params.fetch_block_bytes * 8 + self.data_rd_addr = Signal(self.data_addr_layout) - self.data_rd_data = Array([Signal(self.params.word_width) for _ in range(self.params.num_of_ways)]) + self.data_rd_data = Array([Signal(self.fetch_block_bits) for _ in range(self.params.num_of_ways)]) self.data_wr_addr = Signal(self.data_addr_layout) self.data_wr_en = Signal() - self.data_wr_data = Signal(self.params.word_width) + self.data_wr_data = Signal(self.fetch_block_bits) def elaborate(self, platform): m = TModule() @@ -345,17 +342,18 @@ def elaborate(self, platform): tag_mem_wp.en.eq(self.tag_wr_en & way_wr), ] - data_mem = Memory(width=self.params.word_width, depth=self.params.num_of_sets * self.params.words_in_block) + data_mem = Memory( + width=self.fetch_block_bits, depth=self.params.num_of_sets * self.params.fetch_blocks_in_line + ) data_mem_rp = data_mem.read_port() data_mem_wp = data_mem.write_port() m.submodules[f"data_mem_{i}_rp"] = data_mem_rp m.submodules[f"data_mem_{i}_wp"] = data_mem_wp - # We address 
the data RAM using machine words, so we have to + # We address the data RAM using fetch blocks, so we have to # discard a few least significant bits from the address. - redundant_offset_bits = exact_log2(self.params.word_width_bytes) - rd_addr = Cat(self.data_rd_addr.offset, self.data_rd_addr.index)[redundant_offset_bits:] - wr_addr = Cat(self.data_wr_addr.offset, self.data_wr_addr.index)[redundant_offset_bits:] + rd_addr = Cat(self.data_rd_addr.offset, self.data_rd_addr.index)[self.params.fetch_block_bytes_log :] + wr_addr = Cat(self.data_wr_addr.offset, self.data_wr_addr.index)[self.params.fetch_block_bytes_log :] m.d.comb += [ self.data_rd_data[i].eq(data_mem_rp.data), diff --git a/coreblocks/cache/iface.py b/coreblocks/cache/iface.py index c2c54d2ff..95bb00fd9 100644 --- a/coreblocks/cache/iface.py +++ b/coreblocks/cache/iface.py @@ -35,7 +35,7 @@ class CacheRefillerInterface(HasElaborate, Protocol): start_refill : Method A method that is used to start a refill for a given cache line. accept_refill : Method - A method that is used to accept one word from the requested cache line. + A method that is used to accept one fetch block from the requested cache line. 
""" start_refill: Method diff --git a/coreblocks/cache/refiller.py b/coreblocks/cache/refiller.py index 311764852..92fea2911 100644 --- a/coreblocks/cache/refiller.py +++ b/coreblocks/cache/refiller.py @@ -14,6 +14,7 @@ class SimpleCommonBusCacheRefiller(Elaboratable, CacheRefillerInterface): def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, bus_master: BusMasterInterface): + self.layouts = layouts self.params = params self.bus_master = bus_master @@ -23,51 +24,84 @@ def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, bus_master: def elaborate(self, platform): m = TModule() - refill_address = Signal(self.params.word_width - self.params.offset_bits) + m.submodules.resp_fwd = resp_fwd = Forwarder(self.layouts.accept_refill) + + cache_line_address = Signal(self.params.word_width - self.params.offset_bits) + refill_active = Signal() - word_counter = Signal(range(self.params.words_in_block)) + flushing = Signal() - m.submodules.address_fwd = address_fwd = Forwarder( - [("word_counter", word_counter.shape()), ("refill_address", refill_address.shape())] - ) + sending_requests = Signal() + req_word_counter = Signal(range(self.params.words_in_line)) - with Transaction().body(m): - address = address_fwd.read(m) + with Transaction().body(m, request=sending_requests): self.bus_master.request_read( m, - addr=Cat(address["word_counter"], address["refill_address"]), + addr=Cat(req_word_counter, cache_line_address), sel=C(1).replicate(self.bus_master.params.data_width // self.bus_master.params.granularity), ) - @def_method(m, self.start_refill, ready=~refill_active) - def _(addr) -> None: - address = addr[self.params.offset_bits :] - m.d.sync += refill_address.eq(address) - m.d.sync += refill_active.eq(1) - m.d.sync += word_counter.eq(0) + m.d.sync += req_word_counter.eq(req_word_counter + 1) + with m.If(req_word_counter == (self.params.words_in_line - 1)): + m.d.sync += sending_requests.eq(0) - address_fwd.write(m, word_counter=0, 
refill_address=address) + resp_word_counter = Signal(range(self.params.words_in_line)) + block_buffer = Signal(self.params.word_width * (self.params.words_in_fetch_block - 1)) - @def_method(m, self.accept_refill, ready=refill_active) - def _(): - fetched = self.bus_master.get_read_response(m) + # The transaction reads responses from the bus, builds the fetch block and when + # receives the last word of the fetch block, dispatches it. + with Transaction().body(m): + bus_response = self.bus_master.get_read_response(m) + + block = Signal(self.params.fetch_block_bytes * 8) + m.d.av_comb += block.eq(Cat(block_buffer, bus_response.data)) + m.d.sync += block_buffer.eq(block[self.params.word_width :]) + + words_in_fetch_block_log = exact_log2(self.params.words_in_fetch_block) + current_fetch_block = resp_word_counter[words_in_fetch_block_log:] + word_in_fetch_block = resp_word_counter[:words_in_fetch_block_log] + + with m.If(~flushing): + with m.If((word_in_fetch_block == self.params.words_in_fetch_block - 1) | bus_response.err): + fetch_block_addr = Cat( + C(0, exact_log2(self.params.word_width_bytes)), + C(0, words_in_fetch_block_log), + current_fetch_block, + cache_line_address, + ) + + resp_fwd.write( + m, + addr=fetch_block_addr, + fetch_block=block, + error=bus_response.err, + last=(resp_word_counter == self.params.words_in_line - 1) | bus_response.err, + ) + + with m.If(resp_word_counter == self.params.words_in_line - 1): + m.d.sync += refill_active.eq(0) + with m.Elif(bus_response.err): + m.d.sync += sending_requests.eq(0) + m.d.sync += flushing.eq(1) + + m.d.sync += resp_word_counter.eq(resp_word_counter + 1) + + with m.If(flushing & (resp_word_counter == req_word_counter)): + m.d.sync += refill_active.eq(0) + m.d.sync += flushing.eq(0) - last = (word_counter == (self.params.words_in_block - 1)) | fetched.err + @def_method(m, self.start_refill, ready=~refill_active) + def _(addr) -> None: + m.d.sync += cache_line_address.eq(addr[self.params.offset_bits :]) + 
m.d.sync += req_word_counter.eq(0) + m.d.sync += sending_requests.eq(1) - next_word_counter = Signal.like(word_counter) - m.d.top_comb += next_word_counter.eq(word_counter + 1) + m.d.sync += resp_word_counter.eq(0) - m.d.sync += word_counter.eq(next_word_counter) - with m.If(last): - m.d.sync += refill_active.eq(0) - with m.Else(): - address_fwd.write(m, word_counter=next_word_counter, refill_address=refill_address) + m.d.sync += refill_active.eq(1) - return { - "addr": Cat(C(0, exact_log2(self.params.word_width_bytes)), word_counter, refill_address), - "data": fetched.data, - "error": fetched.err, - "last": last, - } + @def_method(m, self.accept_refill) + def _(): + return resp_fwd.read(m) return m diff --git a/coreblocks/frontend/decoder/isa.py b/coreblocks/frontend/decoder/isa.py index 229d65c9b..10bb72854 100644 --- a/coreblocks/frontend/decoder/isa.py +++ b/coreblocks/frontend/decoder/isa.py @@ -40,6 +40,7 @@ class Opcode(IntEnum, shape=5): JALR = 0b11001 JAL = 0b11011 SYSTEM = 0b11100 + RESERVED = 0b11111 class Funct3(IntEnum, shape=3): diff --git a/coreblocks/frontend/decoder/rvc.py b/coreblocks/frontend/decoder/rvc.py index 4ff48c07d..2fe9d42ee 100644 --- a/coreblocks/frontend/decoder/rvc.py +++ b/coreblocks/frontend/decoder/rvc.py @@ -209,7 +209,7 @@ def _quadrant_2(self) -> list[DecodedInstr]: shamt = Cat(self.instr_in[2:7], self.instr_in[12]) ldsp_imm = Cat(C(0, 3), self.instr_in[5:7], self.instr_in[12], self.instr_in[2:5], C(0, 3)) lwsp_imm = Cat(C(0, 2), self.instr_in[4:7], self.instr_in[12], self.instr_in[2:4], C(0, 4)) - sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 2)) + sdsp_imm = Cat(C(0, 3), self.instr_in[10:13], self.instr_in[7:10], C(0, 3)) swsp_imm = Cat(C(0, 2), self.instr_in[9:13], self.instr_in[7:9], C(0, 4)) slli = ( diff --git a/coreblocks/frontend/fetch/fetch.py b/coreblocks/frontend/fetch/fetch.py index add09c6c1..0901dc451 100644 --- a/coreblocks/frontend/fetch/fetch.py +++ b/coreblocks/frontend/fetch/fetch.py 
@@ -40,6 +40,9 @@ def __init__(self, gen_params: GenParams, icache: CacheInterface, cont: Method) # ExceptionCauseRegister uses separate Transaction for it, so performace is not affected. self.stall_exception.add_conflict(self.resume, Priority.LEFT) + # For now assume that the fetch block is 4 bytes long (a machine word). + assert self.gen_params.fetch_block_bytes == 4 + def elaborate(self, platform): m = TModule() @@ -74,7 +77,7 @@ def stall(exception=False): target = self.fetch_target_queue.read(m) res = self.icache.accept_res(m) - opcode = res.instr[2:7] + opcode = res.fetch_block[2:7] # whether we have to wait for the retirement of this instruction before we make futher speculation unsafe_instr = opcode == Opcode.SYSTEM @@ -90,7 +93,7 @@ def stall(exception=False): with m.If(unsafe_instr): stall() - m.d.comb += instr.eq(res.instr) + m.d.comb += instr.eq(res.fetch_block) self.cont(m, instr=instr, pc=target.addr, access_fault=fetch_error, rvc=0) @@ -136,6 +139,9 @@ def __init__(self, gen_params: GenParams, icache: CacheInterface, cont: Method) self.perf_rvc = HwCounter("frontend.ifu.rvc", "Number of decompressed RVC instructions") + # For now assume that the fetch block is 4 bytes long (a machine word). 
+ assert self.gen_params.fetch_block_bytes == 4 + def elaborate(self, platform) -> TModule: m = TModule() @@ -175,8 +181,8 @@ def elaborate(self, platform) -> TModule: req_limiter.release(m) is_unaligned = current_pc[1] - resp_upper_half = cache_resp.instr[16:] - resp_lower_half = cache_resp.instr[:16] + resp_upper_half = cache_resp.fetch_block[16:] + resp_lower_half = cache_resp.fetch_block[:16] resp_first_half = Mux(is_unaligned, resp_upper_half, resp_lower_half) resp_valid = ~flushing & (cache_resp.error == 0) is_resp_upper_rvc = Signal() @@ -188,7 +194,7 @@ def elaborate(self, platform) -> TModule: is_rvc = is_instr_compressed(instr_lo_half) - full_instr = Mux(half_instr_buff_v, Cat(half_instr_buff, resp_lower_half), cache_resp.instr) + full_instr = Mux(half_instr_buff_v, Cat(half_instr_buff, resp_lower_half), cache_resp.fetch_block) instr = Signal(32) m.d.top_comb += instr.eq(Mux(is_rvc, decompress.instr_out, full_instr)) diff --git a/coreblocks/func_blocks/fu/alu.py b/coreblocks/func_blocks/fu/alu.py index adfcc6a3f..d824cacb3 100644 --- a/coreblocks/func_blocks/fu/alu.py +++ b/coreblocks/func_blocks/fu/alu.py @@ -3,6 +3,7 @@ from transactron import * from transactron.lib import FIFO +from transactron.lib.metrics import * from coreblocks.frontend.decoder.isa import Funct3, Funct7 from coreblocks.frontend.decoder.optypes import OpType @@ -219,9 +220,17 @@ def __init__(self, gen_params: GenParams, alu_fn=AluFn()): self.issue = Method(i=layouts.issue) self.accept = Method(o=layouts.accept) + self.perf_instr = TaggedCounter( + "backend.fu.alu.instr", + "Counts of instructions executed by the jumpbranch unit", + tags=AluFn.Fn, + ) + def elaborate(self, platform): m = TModule() + m.submodules += [self.perf_instr] + m.submodules.alu = alu = Alu(self.gen_params, alu_fn=self.alu_fn) m.submodules.fifo = fifo = FIFO(self.gen_params.get(FuncUnitLayouts).accept, 2) m.submodules.decoder = decoder = self.alu_fn.get_decoder(self.gen_params) @@ -238,6 +247,8 @@ def _(arg): 
m.d.comb += alu.in1.eq(arg.s1_val) m.d.comb += alu.in2.eq(Mux(arg.imm, arg.imm, arg.s2_val)) + self.perf_instr.incr(m, decoder.decode_fn) + fifo.write(m, rob_id=arg.rob_id, result=alu.out, rp_dst=arg.rp_dst, exception=0) return m diff --git a/coreblocks/func_blocks/fu/jumpbranch.py b/coreblocks/func_blocks/fu/jumpbranch.py index aeb6fed22..9730650ee 100644 --- a/coreblocks/func_blocks/fu/jumpbranch.py +++ b/coreblocks/func_blocks/fu/jumpbranch.py @@ -136,8 +136,11 @@ def __init__(self, gen_params: GenParams, jb_fn=JumpBranchFn()): self.dm = gen_params.get(DependencyManager) self.dm.add_dependency(BranchVerifyKey(), self.fifo_branch_resolved.read) - self.perf_jumps = HwCounter("backend.fu.jumpbranch.jumps", "Number of jump instructions issued") - self.perf_branches = HwCounter("backend.fu.jumpbranch.branches", "Number of branch instructions issued") + self.perf_instr = TaggedCounter( + "backend.fu.jumpbranch.instr", + "Counts of instructions executed by the jumpbranch unit", + tags=JumpBranchFn.Fn, + ) self.perf_misaligned = HwCounter( "backend.fu.jumpbranch.misaligned", "Number of instructions with misaligned target address" ) @@ -145,7 +148,10 @@ def __init__(self, gen_params: GenParams, jb_fn=JumpBranchFn()): def elaborate(self, platform): m = TModule() - m.submodules += [self.perf_jumps, self.perf_branches, self.perf_misaligned] + m.submodules += [ + self.perf_instr, + self.perf_misaligned, + ] m.submodules.jb = jb = JumpBranch(self.gen_params, fn=self.jb_fn) m.submodules.fifo_res = fifo_res = FIFO(self.gen_params.get(FuncUnitLayouts).accept, 2) @@ -169,12 +175,10 @@ def _(arg): m.d.top_comb += jb.in_rvc.eq(arg.exec_fn.funct7) is_auipc = decoder.decode_fn == JumpBranchFn.Fn.AUIPC - is_jump = (decoder.decode_fn == JumpBranchFn.Fn.JAL) | (decoder.decode_fn == JumpBranchFn.Fn.JALR) jump_result = Mux(jb.taken, jb.jmp_addr, jb.reg_res) - self.perf_jumps.incr(m, cond=is_jump) - self.perf_branches.incr(m, cond=(~is_jump & ~is_auipc)) + self.perf_instr.incr(m, 
decoder.decode_fn) exception = Signal() exception_report = self.dm.get_dependency(ExceptionReportKey()) @@ -216,7 +220,7 @@ def _(arg): log.debug( m, True, - "jumping from 0x{:08x} to 0x{:08x}; misprediction: {}", + "branch resolved from 0x{:08x} to 0x{:08x}; misprediction: {}", jb.in_pc, jump_result, misprediction, diff --git a/coreblocks/interface/layouts.py b/coreblocks/interface/layouts.py index 5db15302e..0e831f033 100644 --- a/coreblocks/interface/layouts.py +++ b/coreblocks/interface/layouts.py @@ -392,13 +392,16 @@ class ICacheLayouts: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.error: LayoutListField = ("last", 1) + self.last: LayoutListField = ("last", 1) """This is the last cache refill result.""" + self.fetch_block: LayoutListField = ("fetch_block", gen_params.fetch_block_bytes * 8) + """The block of data the fetch unit operates on.""" + self.issue_req = make_layout(fields.addr) self.accept_res = make_layout( - fields.instr, + self.fetch_block, fields.error, ) @@ -408,9 +411,9 @@ def __init__(self, gen_params: GenParams): self.accept_refill = make_layout( fields.addr, - fields.data, + self.fetch_block, fields.error, - self.error, + self.last, ) diff --git a/coreblocks/params/configurations.py b/coreblocks/params/configurations.py index c8dd6810c..1d17289f5 100644 --- a/coreblocks/params/configurations.py +++ b/coreblocks/params/configurations.py @@ -62,8 +62,10 @@ class CoreConfiguration: Associativity of the instruction cache. icache_sets_bits: int Log of the number of sets of the instruction cache. - icache_block_size_bits: int + icache_line_bytes_log: int Log of the cache line size (in bytes). + fetch_block_bytes_log: int + Log of the size of the fetch block (in bytes). allow_partial_extensions: bool Allow partial support of extensions. 
_implied_extensions: Extenstion @@ -93,7 +95,9 @@ def __post_init__(self): icache_enable: bool = True icache_ways: int = 2 icache_sets_bits: int = 7 - icache_block_size_bits: int = 5 + icache_line_bytes_log: int = 5 + + fetch_block_bytes_log: int = 2 allow_partial_extensions: bool = False diff --git a/coreblocks/params/genparams.py b/coreblocks/params/genparams.py index 5b6fe0ce2..33dd5346c 100644 --- a/coreblocks/params/genparams.py +++ b/coreblocks/params/genparams.py @@ -35,16 +35,17 @@ def __init__(self, cfg: CoreConfiguration): self.pma = cfg.pma bytes_in_word = self.isa.xlen // 8 - self.wb_params = WishboneParameters( - data_width=self.isa.xlen, addr_width=self.isa.xlen - exact_log2(bytes_in_word) - ) + bytes_in_word_log = exact_log2(bytes_in_word) + self.wb_params = WishboneParameters(data_width=self.isa.xlen, addr_width=self.isa.xlen - bytes_in_word_log) self.icache_params = ICacheParameters( addr_width=self.isa.xlen, word_width=self.isa.xlen, + fetch_block_bytes_log=cfg.fetch_block_bytes_log, num_of_ways=cfg.icache_ways, num_of_sets_bits=cfg.icache_sets_bits, - block_size_bits=cfg.icache_block_size_bits, + line_bytes_log=cfg.icache_line_bytes_log, + enable=cfg.icache_enable, ) self.debug_signals_enabled = cfg.debug_signals @@ -65,4 +66,9 @@ def __init__(self, cfg: CoreConfiguration): self.max_rs_entries_bits = (self.max_rs_entries - 1).bit_length() self.start_pc = cfg.start_pc + self.fetch_block_bytes_log = cfg.fetch_block_bytes_log + if self.fetch_block_bytes_log < bytes_in_word_log: + raise ValueError("Fetch block must be not smaller than the machine word.") + self.fetch_block_bytes = 2**self.fetch_block_bytes_log + self._toolchain_isa_str = gen_isa_string(extensions, cfg.xlen, skip_internal=True) diff --git a/coreblocks/params/icache_params.py b/coreblocks/params/icache_params.py index 2506d7b37..e71a07bf9 100644 --- a/coreblocks/params/icache_params.py +++ b/coreblocks/params/icache_params.py @@ -11,35 +11,49 @@ class ICacheParameters: Associativity of 
the cache. num_of_sets_bits : int Log of the number of cache sets. - block_size_bits : int - Log of the size of a single cache block in bytes. + line_bytes_log : int + Log of the size of a single cache line in bytes. enable : bool Enable the instruction cache. If disabled, requestes are bypassed to the bus. """ - def __init__(self, *, addr_width, word_width, num_of_ways, num_of_sets_bits, block_size_bits, enable=True): + def __init__( + self, + *, + addr_width, + word_width, + fetch_block_bytes_log, + num_of_ways, + num_of_sets_bits, + line_bytes_log, + enable=True + ): self.addr_width = addr_width self.word_width = word_width + self.fetch_block_bytes_log = fetch_block_bytes_log self.num_of_ways = num_of_ways self.num_of_sets_bits = num_of_sets_bits - self.block_size_bits = block_size_bits + self.line_bytes_log = line_bytes_log self.enable = enable + self.fetch_block_bytes = 2**fetch_block_bytes_log self.num_of_sets = 2**num_of_sets_bits - self.block_size_bytes = 2**block_size_bits - - # We are sanely assuming that the instruction width is 4 bytes. 
- self.instr_width = 32 + self.line_size_bytes = 2**line_bytes_log self.word_width_bytes = word_width // 8 - if self.block_size_bytes % self.word_width_bytes != 0: - raise ValueError("block_size_bytes must be divisble by the machine word size") - - self.offset_bits = block_size_bits + self.offset_bits = line_bytes_log self.index_bits = num_of_sets_bits self.tag_bits = self.addr_width - self.offset_bits - self.index_bits self.index_start_bit = self.offset_bits self.index_end_bit = self.offset_bits + self.index_bits - 1 - self.words_in_block = self.block_size_bytes // self.word_width_bytes + self.words_in_line = self.line_size_bytes // self.word_width_bytes + self.words_in_fetch_block = self.fetch_block_bytes // self.word_width_bytes + self.fetch_blocks_in_line = self.line_size_bytes // self.fetch_block_bytes + + if not enable: + return + + if line_bytes_log < self.fetch_block_bytes_log: + raise ValueError("The instruction cache line size must be not smaller than the fetch block size.") diff --git a/coreblocks/params/instr.py b/coreblocks/params/instr.py index 8f14d6b11..f3755b25d 100644 --- a/coreblocks/params/instr.py +++ b/coreblocks/params/instr.py @@ -1,4 +1,13 @@ -from abc import abstractmethod, ABC +""" + +Based on riscv-python-model by Stefan Wallentowitz +https://github.com/wallento/riscv-python-model +""" + +from dataclasses import dataclass +from abc import ABC +from enum import Enum +from typing import Optional from amaranth.hdl import ValueCastable from amaranth import * @@ -9,6 +18,7 @@ __all__ = [ + "RISCVInstr", "RTypeInstr", "ITypeInstr", "STypeInstr", @@ -20,112 +30,219 @@ ] +@dataclass(kw_only=True) +class Field: + """Information about a field in a RISC-V instruction. + + Attributes + ---------- + base: int | list[int] + A bit position (or a list of positions) where this field (or parts of the field) + would map in the instruction. 
+ size: int | list[int] + Size (or sizes of the parts) of the field + signed: bool + Whether this field encodes a signed value. + offset: int + How many bits of this field should be skipped when encoding the instruction. + For example, the immediate of the jump instruction always skips the least + significant bit. This only affects encoding procedures, so externally (for example + when creating an instance of a instruction) full-size values should be always used. + static_value: Optional[Value] + Whether the field should have a static value for a given type of an instruction. + """ + + base: int | list[int] + size: int | list[int] + + signed: bool = False + offset: int = 0 + static_value: Optional[Value] = None + + _name: str = "" + + def bases(self) -> list[int]: + return [self.base] if isinstance(self.base, int) else self.base + + def sizes(self) -> list[int]: + return [self.size] if isinstance(self.size, int) else self.size + + def shape(self) -> Shape: + return Shape(width=sum(self.sizes()) + self.offset, signed=self.signed) + + def __set_name__(self, owner, name): + self._name = name + + def __get__(self, obj, objtype=None) -> Value: + if self.static_value is not None: + return self.static_value + + return obj.__dict__.get(self._name, C(0, self.shape())) + + def __set__(self, obj, value) -> None: + if self.static_value is not None: + raise AttributeError("Can't overwrite the static value of a field.") + + expected_shape = self.shape() + + field_val: Value = C(0) + if isinstance(value, Enum): + field_val = Const(value.value, expected_shape) + elif isinstance(value, int): + field_val = Const(value, expected_shape) + else: + field_val = Value.cast(value) + + if field_val.shape().width != expected_shape.width: + raise AttributeError( + f"Expected width of the value: {expected_shape.width}, given: {field_val.shape().width}" + ) + if field_val.shape().signed and not expected_shape.signed: + raise AttributeError( + f"Expected signedness of the value: 
{expected_shape.signed}, given: {field_val.shape().signed}" + ) + + obj.__dict__[self._name] = field_val + + def get_parts(self, value: Value) -> list[Value]: + base = self.bases() + size = self.sizes() + offset = self.offset + + ret: list[Value] = [] + for i in range(len(base)): + ret.append(value[offset : offset + size[i]]) + offset += size[i] + + return ret + + +def _get_fields(cls: type) -> list[Field]: + fields = [cls.__dict__[member] for member in vars(cls) if isinstance(cls.__dict__[member], Field)] + field_ids = set([id(field) for field in fields]) + for base in cls.__bases__: + for field in _get_fields(base): + if id(field) in field_ids: + continue + fields.append(field) + field_ids.add(id(field)) + + return fields + + class RISCVInstr(ABC, ValueCastable): - @abstractmethod - def pack(self) -> Value: - pass + opcode = Field(base=0, size=7) + + def __init__(self, opcode: Opcode): + self.opcode = Cat(C(0b11, 2), opcode) + + def encode(self) -> int: + const = Const.cast(self.as_value()) + return const.value # type: ignore @ValueCastable.lowermethod - def as_value(self): - return self.pack() + def as_value(self) -> Value: + parts: list[tuple[int, Value]] = [] + + for field in _get_fields(type(self)): + value = field.__get__(self, type(self)) + parts += zip(field.bases(), field.get_parts(value)) + + parts.sort() + return Cat([part[1] for part in parts]) - def shape(self): + def shape(self) -> Shape: return self.as_value().shape() -class RTypeInstr(RISCVInstr): +class InstructionFunct3Type(RISCVInstr): + funct3 = Field(base=12, size=3) + + +class InstructionFunct7Type(RISCVInstr): + funct7 = Field(base=25, size=7) + + +class RTypeInstr(InstructionFunct3Type, InstructionFunct7Type): + rd = Field(base=7, size=5) + rs1 = Field(base=15, size=5) + rs2 = Field(base=20, size=5) + def __init__( - self, - opcode: ValueLike, - rd: ValueLike, - funct3: ValueLike, - rs1: ValueLike, - rs2: ValueLike, - funct7: ValueLike, + self, opcode: Opcode, funct3: ValueLike, funct7: 
ValueLike, rd: ValueLike, rs1: ValueLike, rs2: ValueLike ): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.rs2 = Value.cast(rs2) - self.funct7 = Value.cast(funct7) - - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.rs2, self.funct7) - - -class ITypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, rd: ValueLike, funct3: ValueLike, rs1: ValueLike, imm: ValueLike): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.imm = Value.cast(imm) - - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.funct3, self.rs1, self.imm) - - -class STypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike): - self.opcode = Value.cast(opcode) - self.imm = Value.cast(imm) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.rs2 = Value.cast(rs2) - - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.imm[0:5], self.funct3, self.rs1, self.rs2, self.imm[5:12]) - - -class BTypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, imm: ValueLike, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike): - self.opcode = Value.cast(opcode) - self.imm = Value.cast(imm) - self.funct3 = Value.cast(funct3) - self.rs1 = Value.cast(rs1) - self.rs2 = Value.cast(rs2) - - def pack(self) -> Value: - return Cat( - C(0b11, 2), - self.opcode, - self.imm[11], - self.imm[1:5], - self.funct3, - self.rs1, - self.rs2, - self.imm[5:11], - self.imm[12], - ) + super().__init__(opcode) + self.funct3 = funct3 + self.funct7 = funct7 + self.rd = rd + self.rs1 = rs1 + self.rs2 = rs2 + + +class ITypeInstr(InstructionFunct3Type): + rd = Field(base=7, size=5) + rs1 = Field(base=15, size=5) + imm = Field(base=20, size=12, signed=True) + + def __init__(self, opcode: 
Opcode, funct3: ValueLike, rd: ValueLike, rs1: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.funct3 = funct3 + self.rd = rd + self.rs1 = rs1 + self.imm = imm + + +class STypeInstr(InstructionFunct3Type): + rs1 = Field(base=15, size=5) + rs2 = Field(base=20, size=5) + imm = Field(base=[7, 25], size=[5, 7], signed=True) + + def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.funct3 = funct3 + self.rs1 = rs1 + self.rs2 = rs2 + self.imm = imm + + +class BTypeInstr(InstructionFunct3Type): + rs1 = Field(base=15, size=5) + rs2 = Field(base=20, size=5) + imm = Field(base=[8, 25, 7, 31], size=[4, 6, 1, 1], offset=1, signed=True) + + def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.funct3 = funct3 + self.rs1 = rs1 + self.rs2 = rs2 + self.imm = imm class UTypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.imm = Value.cast(imm) + rd = Field(base=7, size=5) + imm = Field(base=12, size=20, offset=12, signed=True) - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:]) + def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.rd = rd + self.imm = imm class JTypeInstr(RISCVInstr): - def __init__(self, opcode: ValueLike, rd: ValueLike, imm: ValueLike): - self.opcode = Value.cast(opcode) - self.rd = Value.cast(rd) - self.imm = Value.cast(imm) + rd = Field(base=7, size=5) + imm = Field(base=[21, 20, 12, 31], size=[10, 1, 8, 1], offset=1, signed=True) - def pack(self) -> Value: - return Cat(C(0b11, 2), self.opcode, self.rd, self.imm[12:20], self.imm[11], self.imm[1:11], self.imm[20]) + def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike): + super().__init__(opcode) + self.rd = rd + self.imm = imm 
class IllegalInstr(RISCVInstr): - def __init__(self): - pass + illegal = Field(base=7, size=25, static_value=Cat(1).replicate(25)) - def pack(self) -> Value: - return C(1).replicate(32) # Instructions with all bits set to 1 are reserved to be illegal. + def __init__(self): + super().__init__(opcode=Opcode.RESERVED) class EBreakInstr(ITypeInstr): diff --git a/requirements-dev.txt b/requirements-dev.txt index 1d9530305..fa39140f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,8 @@ -r requirements.txt -black==23.3.0 +black==24.3.0 docutils==0.15.2 flake8==6.0.0 pep8-naming==0.13.3 -git+https://github.com/kristopher38/riscv-python-model@b5d0737#riscv-model markupsafe==2.0.1 myst-parser==0.18.0 numpydoc==1.5.0 diff --git a/test/cache/test_icache.py b/test/cache/test_icache.py index 3bd198c43..f53cff894 100644 --- a/test/cache/test_icache.py +++ b/test/cache/test_icache.py @@ -53,21 +53,25 @@ def elaborate(self, platform): @parameterized_class( - ("name", "isa_xlen", "block_size"), + ("name", "isa_xlen", "line_size", "fetch_block"), [ - ("blk_size16B_rv32i", 32, 4), - ("blk_size32B_rv32i", 32, 5), - ("blk_size32B_rv64i", 64, 5), - ("blk_size64B_rv32i", 32, 6), + ("line16B_block4B_rv32i", 32, 4, 2), + ("line32B_block8B_rv32i", 32, 5, 3), + ("line32B_block8B_rv64i", 64, 5, 3), + ("line64B_block16B_rv32i", 32, 6, 4), + ("line16B_block16B_rv32i", 32, 4, 4), ], ) class TestSimpleCommonBusCacheRefiller(TestCaseWithSimulator): isa_xlen: int - block_size: int + line_size: int + fetch_block: int def setUp(self) -> None: self.gen_params = GenParams( - test_core_config.replace(xlen=self.isa_xlen, icache_block_size_bits=self.block_size) + test_core_config.replace( + xlen=self.isa_xlen, icache_line_bytes_log=self.line_size, fetch_block_bytes_log=self.fetch_block + ) ) self.cp = self.gen_params.icache_params self.test_module = SimpleCommonBusCacheRefillerTestCircuit(self.gen_params) @@ -75,22 +79,24 @@ def setUp(self) -> None: random.seed(42) 
self.bad_addresses = set() + self.bad_fetch_blocks = set() self.mem = dict() self.requests = deque() for _ in range(100): # Make the address aligned to the beginning of a cache line - addr = random.randrange(2**self.gen_params.isa.xlen) & ~(self.cp.block_size_bytes - 1) + addr = random.randrange(2**self.gen_params.isa.xlen) & ~(self.cp.line_size_bytes - 1) self.requests.append(addr) if random.random() < 0.21: # Choose an address in this cache line to be erroneous - bad_addr = addr + random.randrange(self.cp.block_size_bytes) + bad_addr = addr + random.randrange(self.cp.line_size_bytes) # Make the address aligned to the machine word size bad_addr = bad_addr & ~(self.cp.word_width_bytes - 1) self.bad_addresses.add(bad_addr) + self.bad_fetch_blocks.add(bad_addr & ~(self.cp.fetch_block_bytes - 1)) def wishbone_slave(self): yield Passive() @@ -119,22 +125,26 @@ def refiller_process(self): req_addr = self.requests.pop() yield from self.test_module.start_refill.call(addr=req_addr) - for i in range(self.cp.words_in_block): + for i in range(self.cp.fetch_blocks_in_line): ret = yield from self.test_module.accept_refill.call() - cur_addr = req_addr + i * self.cp.word_width_bytes + cur_addr = req_addr + i * self.cp.fetch_block_bytes self.assertEqual(ret["addr"], cur_addr) - if cur_addr in self.bad_addresses: + if cur_addr in self.bad_fetch_blocks: self.assertEqual(ret["error"], 1) self.assertEqual(ret["last"], 1) break - self.assertEqual(ret["data"], self.mem[ret["addr"]]) + fetch_block = ret["fetch_block"] + for j in range(self.cp.words_in_fetch_block): + word = (fetch_block >> (j * self.cp.word_width)) & (2**self.cp.word_width - 1) + self.assertEqual(word, self.mem[cur_addr + j * self.cp.word_width_bytes]) + self.assertEqual(ret["error"], 0) - last = 1 if i == self.cp.words_in_block - 1 else 0 + last = 1 if i == self.cp.fetch_blocks_in_line - 1 else 0 self.assertEqual(ret["last"], last) def test(self): @@ -170,17 +180,20 @@ def elaborate(self, platform): 
@parameterized_class( - ("name", "isa_xlen"), + ("name", "isa_xlen", "fetch_block"), [ - ("rv32i", 32), - ("rv64i", 64), + ("rv32i", 32, 2), + ("rv64i", 64, 3), ], ) class TestICacheBypass(TestCaseWithSimulator): isa_xlen: str + fetch_block: int def setUp(self) -> None: - self.gen_params = GenParams(test_core_config.replace(xlen=self.isa_xlen)) + self.gen_params = GenParams( + test_core_config.replace(xlen=self.isa_xlen, fetch_block_bytes_log=self.fetch_block, icache_enable=False) + ) self.cp = self.gen_params.icache_params self.m = ICacheBypassTestCircuit(self.gen_params) @@ -231,7 +244,7 @@ def wishbone_slave(self): def user_process(self): while self.requests: - req_addr = self.requests.popleft() + req_addr = self.requests.popleft() & ~(self.cp.fetch_block_bytes - 1) yield from self.m.issue_req.call(addr=req_addr) while random.random() < 0.5: @@ -243,7 +256,11 @@ def user_process(self): self.assertTrue(ret["error"]) else: self.assertFalse(ret["error"]) - self.assertEqual(ret["instr"], self.mem[req_addr]) + + data = self.mem[req_addr] + if self.gen_params.isa.xlen == 64: + data |= self.mem[req_addr + 4] << 32 + self.assertEqual(ret["fetch_block"], data) while random.random() < 0.5: yield @@ -291,16 +308,18 @@ def elaborate(self, platform): @parameterized_class( - ("name", "isa_xlen", "block_size"), + ("name", "isa_xlen", "line_size", "fetch_block"), [ - ("blk_size16B_rv32i", 32, 4), - ("blk_size64B_rv32i", 32, 6), - ("blk_size32B_rv64i", 64, 5), + ("line16B_block8B_rv32i", 32, 4, 2), + ("line64B_block16B_rv32i", 32, 6, 4), + ("line32B_block16B_rv64i", 64, 5, 4), + ("line32B_block32B_rv64i", 64, 5, 5), ], ) class TestICache(TestCaseWithSimulator): isa_xlen: int - block_size: int + line_size: int + fetch_block: int def setUp(self) -> None: random.seed(42) @@ -321,7 +340,8 @@ def init_module(self, ways, sets) -> None: xlen=self.isa_xlen, icache_ways=ways, icache_sets_bits=exact_log2(sets), - icache_block_size_bits=self.block_size, + 
icache_line_bytes_log=self.line_size, + fetch_block_bytes_log=self.fetch_block, ) ) self.cp = self.gen_params.icache_params @@ -330,32 +350,32 @@ def init_module(self, ways, sets) -> None: @def_method_mock(lambda self: self.m.refiller.start_refill_mock) def start_refill_mock(self, addr): self.refill_requests.append(addr) - self.refill_word_cnt = 0 + self.refill_block_cnt = 0 self.refill_in_fly = True self.refill_addr = addr @def_method_mock(lambda self: self.m.refiller.accept_refill_mock, enable=lambda self: self.refill_in_fly) def accept_refill_mock(self): - addr = self.refill_addr + self.refill_word_cnt * self.cp.word_width_bytes - data = self.load_or_gen_mem(addr) - if self.gen_params.isa.xlen == 64: - data = self.load_or_gen_mem(addr + 4) << 32 | data + addr = self.refill_addr + self.refill_block_cnt * self.cp.fetch_block_bytes - self.refill_word_cnt += 1 + fetch_block = 0 + bad_addr = False + for i in range(0, self.cp.fetch_block_bytes, 4): + fetch_block |= self.load_or_gen_mem(addr + i) << (8 * i) + if addr + i in self.bad_addrs: + bad_addr = True - err = addr in self.bad_addrs - if self.gen_params.isa.xlen == 64: - err = err or (addr + 4) in self.bad_addrs + self.refill_block_cnt += 1 - last = self.refill_word_cnt == self.cp.words_in_block or err + last = self.refill_block_cnt == self.cp.fetch_blocks_in_line or bad_addr if last: self.refill_in_fly = False return { "addr": addr, - "data": data, - "error": err, + "fetch_block": fetch_block, + "error": bad_addr, "last": last, } @@ -380,13 +400,17 @@ def expect_resp(self, wait=False): self.assert_resp((yield from self.m.accept_res.get_outputs())) def assert_resp(self, resp: RecordIntDictRet): - addr = self.issued_requests.popleft() + addr = self.issued_requests.popleft() & ~(self.cp.fetch_block_bytes - 1) if (addr & ~((1 << self.cp.offset_bits) - 1)) in self.bad_cache_lines: self.assertTrue(resp["error"]) else: self.assertFalse(resp["error"]) - self.assertEqual(resp["instr"], self.mem[addr]) + fetch_block = 0 + 
for i in range(0, self.cp.fetch_block_bytes, 4): + fetch_block |= self.mem[addr + i] << (8 * i) + + self.assertEqual(resp["fetch_block"], fetch_block) def expect_refill(self, addr: int): self.assertEqual(self.refill_requests.popleft(), addr) @@ -407,13 +431,13 @@ def cache_user_process(): self.expect_refill(0x00010000) # Accesses to the same cache line shouldn't cause a cache miss - for i in range(self.cp.words_in_block): - yield from self.call_cache(0x00010000 + i * 4) + for i in range(self.cp.fetch_blocks_in_line): + yield from self.call_cache(0x00010000 + i * self.cp.fetch_block_bytes) self.assertEqual(len(self.refill_requests), 0) # Now go beyond the first cache line - yield from self.call_cache(0x00010000 + self.cp.block_size_bytes) - self.expect_refill(0x00010000 + self.cp.block_size_bytes) + yield from self.call_cache(0x00010000 + self.cp.line_size_bytes) + self.expect_refill(0x00010000 + self.cp.line_size_bytes) # Trigger cache aliasing yield from self.call_cache(0x00020000) @@ -422,14 +446,14 @@ def cache_user_process(): self.expect_refill(0x00010000) # Fill the whole cache - for i in range(0, self.cp.block_size_bytes * self.cp.num_of_sets, 4): + for i in range(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4): yield from self.call_cache(i) for i in range(self.cp.num_of_sets): - self.expect_refill(i * self.cp.block_size_bytes) + self.expect_refill(i * self.cp.line_size_bytes) # Now do some accesses within the cached memory for i in range(50): - yield from self.call_cache(random.randrange(0, self.cp.block_size_bytes * self.cp.num_of_sets, 4)) + yield from self.call_cache(random.randrange(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4)) self.assertEqual(len(self.refill_requests), 0) with self.run_simulation(self.m) as sim: @@ -460,7 +484,7 @@ def test_pipeline(self): def cache_process(): # Fill the cache for i in range(self.cp.num_of_sets): - addr = 0x00010000 + i * self.cp.block_size_bytes + addr = 0x00010000 + i * self.cp.line_size_bytes yield from 
self.call_cache(addr) self.expect_refill(addr) @@ -468,7 +492,7 @@ def cache_process(): # Create a stream of requests to ensure the pipeline is working yield from self.m.accept_res.enable() - for i in range(0, self.cp.num_of_sets * self.cp.block_size_bytes, 4): + for i in range(0, self.cp.num_of_sets * self.cp.line_size_bytes, 4): addr = 0x00010000 + i self.issued_requests.append(addr) @@ -488,7 +512,7 @@ def cache_process(): yield from self.tick(5) # Check how the cache handles queuing the requests - yield from self.send_req(addr=0x00010000 + 3 * self.cp.block_size_bytes) + yield from self.send_req(addr=0x00010000 + 3 * self.cp.line_size_bytes) yield from self.send_req(addr=0x00010004) # Wait a few cycles. There are two requests queued @@ -508,7 +532,7 @@ def cache_process(): # Schedule two requests, the first one causing a cache miss yield from self.send_req(addr=0x00020000) - yield from self.send_req(addr=0x00010000 + self.cp.block_size_bytes) + yield from self.send_req(addr=0x00010000 + self.cp.line_size_bytes) yield from self.m.accept_res.enable() @@ -522,7 +546,7 @@ def cache_process(): # Schedule two requests, the second one causing a cache miss yield from self.send_req(addr=0x00020004) - yield from self.send_req(addr=0x00030000 + self.cp.block_size_bytes) + yield from self.send_req(addr=0x00030000 + self.cp.line_size_bytes) yield from self.m.accept_res.enable() @@ -536,7 +560,7 @@ def cache_process(): # Schedule two requests, both causing a cache miss yield from self.send_req(addr=0x00040000) - yield from self.send_req(addr=0x00050000 + self.cp.block_size_bytes) + yield from self.send_req(addr=0x00050000 + self.cp.line_size_bytes) yield from self.m.accept_res.enable() @@ -556,14 +580,14 @@ def cache_process(): # Fill the whole cache for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): - addr = w * 0x00010000 + s * self.cp.block_size_bytes + addr = w * 0x00010000 + s * self.cp.line_size_bytes yield from self.call_cache(addr) 
self.expect_refill(addr) # Everything should be in the cache for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): - addr = w * 0x00010000 + s * self.cp.block_size_bytes + addr = w * 0x00010000 + s * self.cp.line_size_bytes yield from self.call_cache(addr) self.assertEqual(len(self.refill_requests), 0) @@ -573,7 +597,7 @@ def cache_process(): # The cache should be empty for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): - addr = w * 0x00010000 + s * self.cp.block_size_bytes + addr = w * 0x00010000 + s * self.cp.line_size_bytes yield from self.call_cache(addr) self.expect_refill(addr) @@ -605,7 +629,7 @@ def cache_process(): yield # Schedule two requests and then flush - yield from self.send_req(0x00000000 + self.cp.block_size_bytes) + yield from self.send_req(0x00000000 + self.cp.line_size_bytes) yield from self.send_req(0x00010000) yield from self.m.flush_cache.call() self.mem[0x00010000] = random.randrange(2**self.gen_params.isa.ilen) @@ -613,7 +637,7 @@ def cache_process(): # And accept the results self.assert_resp((yield from self.m.accept_res.call())) self.assert_resp((yield from self.m.accept_res.call())) - self.expect_refill(0x00000000 + self.cp.block_size_bytes) + self.expect_refill(0x00000000 + self.cp.line_size_bytes) # Just make sure that the line is truly flushed yield from self.call_cache(0x00010000) @@ -629,7 +653,7 @@ def cache_process(): self.add_bad_addr(0x00010000) # Bad addr at the beggining of the line self.add_bad_addr(0x00020008) # Bad addr in the middle of the line self.add_bad_addr( - 0x00030000 + self.cp.block_size_bytes - self.cp.word_width_bytes + 0x00030000 + self.cp.line_size_bytes - self.cp.word_width_bytes ) # Bad addr at the end of the line yield from self.call_cache(0x00010008) @@ -691,6 +715,30 @@ def cache_process(): yield from self.expect_resp(wait=True) yield yield from self.m.accept_res.disable() + yield + + # The second request will cause an error + yield from 
self.send_req(addr=0x00021004) + yield from self.send_req(addr=0x00030000) + + yield from self.tick(10) + + # Accept the first response + yield from self.m.accept_res.enable() + yield from self.expect_resp(wait=True) + yield + + # Wait before accepting the second response + yield from self.m.accept_res.disable() + yield from self.tick(10) + yield from self.m.accept_res.enable() + yield from self.expect_resp(wait=True) + + yield + + # This request should not cause an error + yield from self.send_req(addr=0x00011000) + yield from self.expect_resp(wait=True) with self.run_simulation(self.m) as sim: sim.add_sync_process(cache_process) @@ -698,7 +746,7 @@ def cache_process(): def test_random(self): self.init_module(4, 8) - max_addr = 16 * self.cp.block_size_bytes * self.cp.num_of_sets + max_addr = 16 * self.cp.line_size_bytes * self.cp.num_of_sets iterations = 1000 for i in range(0, max_addr, 4): diff --git a/test/frontend/test_fetch.py b/test/frontend/test_fetch.py index b9ff1388c..3684f7cad 100644 --- a/test/frontend/test_fetch.py +++ b/test/frontend/test_fetch.py @@ -84,7 +84,7 @@ def cache_process(self): data |= 0b1100000 data &= ~0b0010000 # but not system - self.output_q.append({"instr": data, "error": 0}) + self.output_q.append({"fetch_block": data, "error": 0}) # Speculative fetch. Skip, because this instruction shouldn't be executed. 
if addr != next_pc: @@ -229,7 +229,7 @@ def get_mem_or_random(addr): data = (get_mem_or_random(req_addr + 2) << 16) | get_mem_or_random(req_addr) err = (req_addr in self.memerr) or (req_addr + 2 in self.memerr) - self.output_q.append({"instr": data, "error": err}) + self.output_q.append({"fetch_block": data, "error": err}) @def_method_mock(lambda self: self.icache.issue_req_io, enable=lambda self: len(self.input_q) < 2, sched_prio=1) def issue_req_mock(self, addr): diff --git a/test/frontend/test_rvc.py b/test/frontend/test_rvc.py index 0b099f751..8d8fba5a5 100644 --- a/test/frontend/test_rvc.py +++ b/test/frontend/test_rvc.py @@ -25,17 +25,17 @@ # c.addi x2, -28 ( 0x1111, - ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X2, funct3=Funct3.ADD, rs1=Registers.X2, imm=C(-28, 12)), + ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X2, funct3=Funct3.ADD, rs1=Registers.X2, imm=-28), ), # c.li x31, -7 ( 0x5FE5, - ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X31, funct3=Funct3.ADD, rs1=Registers.ZERO, imm=C(-7, 12)), + ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.X31, funct3=Funct3.ADD, rs1=Registers.ZERO, imm=-7), ), # c.addi16sp 496 (0x617D, ITypeInstr(opcode=Opcode.OP_IMM, rd=Registers.SP, funct3=Funct3.ADD, rs1=Registers.SP, imm=496)), # c.lui x7, -3 - (0x73F5, UTypeInstr(opcode=Opcode.LUI, rd=Registers.X7, imm=C(-3, 20) << 12)), + (0x73F5, UTypeInstr(opcode=Opcode.LUI, rd=Registers.X7, imm=Cat(C(0, 12), C(-3, 20)))), # c.srli x10, 3 ( 0x810D, @@ -44,7 +44,7 @@ rd=Registers.X10, funct3=Funct3.SR, rs1=Registers.X10, - rs2=C(3, 5), + rs2=Registers.X3, funct7=Funct7.SL, ), ), @@ -56,7 +56,7 @@ rd=Registers.X12, funct3=Funct3.SR, rs1=Registers.X12, - rs2=C(8, 5), + rs2=Registers.X8, funct7=Funct7.SA, ), ), @@ -111,16 +111,16 @@ ), ), # c.j 2012 - (0xAFF1, JTypeInstr(opcode=Opcode.JAL, rd=Registers.ZERO, imm=C(2012, 21))), + (0xAFF1, JTypeInstr(opcode=Opcode.JAL, rd=Registers.ZERO, imm=2012)), # c.beqz x8, -6 ( 0xDC6D, - BTypeInstr(opcode=Opcode.BRANCH, imm=C(-6, 13), 
funct3=Funct3.BEQ, rs1=Registers.X8, rs2=Registers.ZERO), + BTypeInstr(opcode=Opcode.BRANCH, imm=-6, funct3=Funct3.BEQ, rs1=Registers.X8, rs2=Registers.ZERO), ), # c.bnez x15, 20 ( 0xEB91, - BTypeInstr(opcode=Opcode.BRANCH, imm=C(20, 13), funct3=Funct3.BNE, rs1=Registers.X15, rs2=Registers.ZERO), + BTypeInstr(opcode=Opcode.BRANCH, imm=20, funct3=Funct3.BNE, rs1=Registers.X15, rs2=Registers.ZERO), ), # c.slli x13, 31 ( @@ -130,18 +130,16 @@ rd=Registers.X13, funct3=Funct3.SLL, rs1=Registers.X13, - rs2=C(31, 5), + rs2=Registers.X31, funct7=Funct7.SL, ), ), # c.lwsp x2, 4 - (0x4112, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X2, funct3=Funct3.W, rs1=Registers.SP, imm=C(4, 12))), + (0x4112, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X2, funct3=Funct3.W, rs1=Registers.SP, imm=4)), # c.jr x30 ( 0x8F02, - ITypeInstr( - opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=Registers.X30, imm=C(0).replicate(12) - ), + ITypeInstr(opcode=Opcode.JALR, rd=Registers.ZERO, funct3=Funct3.JALR, rs1=Registers.X30, imm=0), ), # c.mv x2, x26 ( @@ -170,7 +168,7 @@ ), ), # c.swsp x31, 20 - (0xCA7E, STypeInstr(opcode=Opcode.STORE, imm=C(20, 12), funct3=Funct3.W, rs1=Registers.SP, rs2=Registers.X31)), + (0xCA7E, STypeInstr(opcode=Opcode.STORE, imm=20, funct3=Funct3.W, rs1=Registers.SP, rs2=Registers.X31)), ] RV32_TESTS = [ @@ -179,9 +177,9 @@ # c.sd x14, 0(x13) (0xE298, IllegalInstr()), # c.jal 40 - (0x2025, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=C(40, 21))), + (0x2025, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=40)), # c.jal -412 - (0x3595, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=C(-412, 21))), + (0x3595, JTypeInstr(opcode=Opcode.JAL, rd=Registers.RA, imm=-412)), # c.srli x10, 32 (0x9101, IllegalInstr()), # c.srai x12, 40 @@ -196,13 +194,13 @@ RV64_TESTS = [ # c.ld x8, 8(x9) - (0x6480, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X8, funct3=Funct3.D, rs1=Registers.X9, imm=C(8, 12))), + (0x6480, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X8, 
funct3=Funct3.D, rs1=Registers.X9, imm=8)), # c.sd x14, 0(x13) - (0xE298, STypeInstr(opcode=Opcode.STORE, imm=C(0, 12), funct3=Funct3.D, rs1=Registers.X13, rs2=Registers.X14)), + (0xE298, STypeInstr(opcode=Opcode.STORE, imm=0, funct3=Funct3.D, rs1=Registers.X13, rs2=Registers.X14)), # c.addiw x13, -12, ( 0x36D1, - ITypeInstr(opcode=Opcode.OP_IMM_32, rd=Registers.X13, funct3=Funct3.ADD, rs1=Registers.X13, imm=C(-12, 12)), + ITypeInstr(opcode=Opcode.OP_IMM_32, rd=Registers.X13, funct3=Funct3.ADD, rs1=Registers.X13, imm=-12), ), # c.srli x10, 32 ( @@ -212,7 +210,7 @@ rd=Registers.X10, funct3=Funct3.SR, rs1=Registers.X10, - rs2=C(0, 5), + rs2=Registers.X0, funct7=Funct7.SL | 1, ), ), @@ -224,7 +222,7 @@ rd=Registers.X12, funct3=Funct3.SR, rs1=Registers.X12, - rs2=C(8, 5), + rs2=Registers.X8, funct7=Funct7.SA | 1, ), ), @@ -260,14 +258,14 @@ rd=Registers.X13, funct3=Funct3.SLL, rs1=Registers.X13, - rs2=C(31, 5), + rs2=Registers.X31, funct7=Funct7.SL | 1, ), ), # c.ldsp x29, 40 - (0x7EA2, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X29, funct3=Funct3.D, rs1=Registers.SP, imm=C(40, 12))), + (0x7EA2, ITypeInstr(opcode=Opcode.LOAD, rd=Registers.X29, funct3=Funct3.D, rs1=Registers.SP, imm=40)), # c.sdsp x4, 8 - (0xE412, STypeInstr(opcode=Opcode.STORE, imm=C(8, 12), funct3=Funct3.D, rs1=Registers.SP, rs2=Registers.X4)), + (0xE412, STypeInstr(opcode=Opcode.STORE, imm=8, funct3=Funct3.D, rs1=Registers.SP, rs2=Registers.X4)), ] @@ -280,7 +278,9 @@ class TestInstrDecompress(TestCaseWithSimulator): test_cases: list[tuple[int, ValueLike]] def test(self): - self.gen_params = GenParams(test_core_config.replace(compressed=True, xlen=self.isa_xlen)) + self.gen_params = GenParams( + test_core_config.replace(compressed=True, xlen=self.isa_xlen, fetch_block_bytes_log=3) + ) self.m = InstrDecompress(self.gen_params) def process(): diff --git a/test/lsu/test_dummylsu.py b/test/lsu/test_dummylsu.py index 4211720a6..776f0e2cd 100644 --- a/test/lsu/test_dummylsu.py +++ 
b/test/lsu/test_dummylsu.py @@ -173,9 +173,9 @@ def generate_instr(self, max_reg_val, max_imm_val): self.exception_queue.append( { "rob_id": rob_id, - "cause": ExceptionCause.LOAD_ADDRESS_MISALIGNED - if misaligned - else ExceptionCause.LOAD_ACCESS_FAULT, + "cause": ( + ExceptionCause.LOAD_ADDRESS_MISALIGNED if misaligned else ExceptionCause.LOAD_ACCESS_FAULT + ), "pc": 0, } ) diff --git a/test/params/test_instr.py b/test/params/test_instr.py new file mode 100644 index 000000000..0ed97e19c --- /dev/null +++ b/test/params/test_instr.py @@ -0,0 +1,63 @@ +import unittest +from typing import Sequence + +from amaranth import * + +from coreblocks.params.instr import * +from coreblocks.frontend.decoder.isa import * + + +class InstructionTest(unittest.TestCase): + def do_run(self, test_cases: Sequence[tuple[RISCVInstr, int]]): + for instr, raw_instr in test_cases: + self.assertEqual(instr.encode(), raw_instr) + + def test_r_type(self): + test_cases = [ + (RTypeInstr(opcode=Opcode.OP, rd=21, funct3=Funct3.AND, rs1=10, rs2=31, funct7=Funct7.AND), 0x1F57AB3), + ] + + self.do_run(test_cases) + + def test_i_type(self): + test_cases = [ + (ITypeInstr(opcode=Opcode.LOAD_FP, rd=22, funct3=Funct3.D, rs1=10, imm=2047), 0x7FF53B07), + (ITypeInstr(opcode=Opcode.LOAD_FP, rd=22, funct3=Funct3.D, rs1=10, imm=-2048), 0x80053B07), + ] + + self.do_run(test_cases) + + def test_s_type(self): + test_cases = [ + (STypeInstr(opcode=Opcode.STORE_FP, imm=2047, funct3=Funct3.D, rs1=31, rs2=0), 0x7E0FBFA7), + (STypeInstr(opcode=Opcode.STORE_FP, imm=-2048, funct3=Funct3.D, rs1=5, rs2=13), 0x80D2B027), + ] + + self.do_run(test_cases) + + def test_b_type(self): + test_cases = [ + (BTypeInstr(opcode=Opcode.BRANCH, imm=4094, funct3=Funct3.BNE, rs1=10, rs2=0), 0x7E051FE3), + (BTypeInstr(opcode=Opcode.BRANCH, imm=-4096, funct3=Funct3.BEQ, rs1=31, rs2=4), 0x804F8063), + ] + + self.do_run(test_cases) + + def test_u_type(self): + test_cases = [ + (UTypeInstr(opcode=Opcode.LUI, rd=10, imm=3102 << 12), 
0xC1E537), + (UTypeInstr(opcode=Opcode.LUI, rd=31, imm=1048575 << 12), 0xFFFFFFB7), + ] + + self.do_run(test_cases) + + def test_j_type(self): + test_cases = [ + (JTypeInstr(opcode=Opcode.JAL, rd=0, imm=0), 0x6F), + (JTypeInstr(opcode=Opcode.JAL, rd=0, imm=2), 0x20006F), + (JTypeInstr(opcode=Opcode.JAL, rd=10, imm=1048572), 0x7FDFF56F), + (JTypeInstr(opcode=Opcode.JAL, rd=3, imm=-230), 0xF1BFF1EF), + (JTypeInstr(opcode=Opcode.JAL, rd=15, imm=-1048576), 0x800007EF), + ] + + self.do_run(test_cases) diff --git a/test/regression/memory.py b/test/regression/memory.py index 70b8a9496..a34ef764d 100644 --- a/test/regression/memory.py +++ b/test/regression/memory.py @@ -164,9 +164,9 @@ def load_segment(segment: Segment, *, disable_write_protection: bool = False) -> config = CoreConfiguration() if flags_raw & P_FLAGS.PF_X: # align instruction section to full icache lines - align_bits = config.icache_block_size_bits + align_bits = config.icache_line_bytes_log # workaround for fetching/stalling issue - extend_end = 2**config.icache_block_size_bits + extend_end = 2**config.icache_line_bytes_log else: align_bits = 0 extend_end = 0 diff --git a/test/test_core.py b/test/test_core.py index a2cfd1d88..7bb939ac8 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -7,7 +7,9 @@ from transactron.testing import TestCaseWithSimulator, TestbenchIO from coreblocks.core import Core +from coreblocks.frontend.decoder import Opcode, Funct3 from coreblocks.params import GenParams +from coreblocks.params.instr import * from coreblocks.params.configurations import CoreConfiguration, basic_core_config, full_core_config from coreblocks.peripherals.wishbone import WishboneSignature, WishboneMemorySlave @@ -16,10 +18,6 @@ import subprocess import tempfile from parameterized import parameterized_class -from riscvmodel.insn import ( - InstructionADDI, - InstructionLUI, -) class CoreTestElaboratable(Elaboratable): @@ -38,7 +36,7 @@ def elaborate(self, platform): wb_data_bus = 
WishboneSignature(self.gen_params.wb_params).create() # Align the size of the memory to the length of a cache line. - instr_mem_depth = align_to_power_of_two(len(self.instr_mem), self.gen_params.icache_params.block_size_bits) + instr_mem_depth = align_to_power_of_two(len(self.instr_mem), self.gen_params.icache_params.line_bytes_log) self.wb_mem_slave = WishboneMemorySlave( wb_params=self.gen_params.wb_params, width=32, depth=instr_mem_depth, init=self.instr_mem ) @@ -81,8 +79,10 @@ def push_register_load_imm(self, reg_id, val): if val & 0x800: lui_imm = (lui_imm + 1) & (0xFFFFF) - yield from self.push_instr(InstructionLUI(reg_id, lui_imm).encode()) - yield from self.push_instr(InstructionADDI(reg_id, reg_id, addi_imm).encode()) + yield from self.push_instr(UTypeInstr(opcode=Opcode.LUI, rd=reg_id, imm=lui_imm << 12).encode()) + yield from self.push_instr( + ITypeInstr(opcode=Opcode.OP_IMM, rd=reg_id, funct3=Funct3.ADD, rs1=reg_id, imm=addi_imm).encode() + ) class TestCoreAsmSourceBase(TestCoreBase): diff --git a/test/transactron/test_connectors.py b/test/transactron/test_connectors.py new file mode 100644 index 000000000..2903397b6 --- /dev/null +++ b/test/transactron/test_connectors.py @@ -0,0 +1,46 @@ +import random +from parameterized import parameterized_class + +from amaranth.sim import Settle + +from transactron.lib import StableSelectingNetwork +from transactron.testing import TestCaseWithSimulator + + +@parameterized_class( + ("n"), + [(2,), (3,), (7,), (8,)], +) +class TestStableSelectingNetwork(TestCaseWithSimulator): + n: int + + def test(self): + m = StableSelectingNetwork(self.n, [("data", 8)]) + + random.seed(42) + + def process(): + for _ in range(100): + inputs = [random.randrange(2**8) for _ in range(self.n)] + valids = [random.randrange(2) for _ in range(self.n)] + total = sum(valids) + + expected_output_prefix = [] + for i in range(self.n): + yield m.valids[i].eq(valids[i]) + yield m.inputs[i].data.eq(inputs[i]) + + if valids[i]: + 
expected_output_prefix.append(inputs[i]) + + yield Settle() + + for i in range(total): + out = yield m.outputs[i].data + self.assertEqual(out, expected_output_prefix[i]) + + self.assertEqual((yield m.output_cnt), total) + yield + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) diff --git a/test/transactron/test_metrics.py b/test/transactron/test_metrics.py index 6b0e4f738..a8af19af9 100644 --- a/test/transactron/test_metrics.py +++ b/test/transactron/test_metrics.py @@ -1,6 +1,9 @@ import json import random import queue +from typing import Type +from enum import IntFlag, IntEnum, auto, Enum + from parameterized import parameterized_class from amaranth import * @@ -139,6 +142,85 @@ def test_process(): sim.add_sync_process(test_process) +class OneHotEnum(IntFlag): + ADD = auto() + XOR = auto() + OR = auto() + + +class PlainIntEnum(IntEnum): + TEST_1 = auto() + TEST_2 = auto() + TEST_3 = auto() + + +class TaggedCounterCircuit(Elaboratable): + def __init__(self, tags: range | Type[Enum] | list[int]): + self.counter = TaggedCounter("counter", "", tags=tags) + + self.cond = Signal() + self.tag = Signal(self.counter.tag_width) + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + with Transaction().body(m): + self.counter.incr(m, self.tag, cond=self.cond) + + return m + + +class TestTaggedCounter(TestCaseWithSimulator): + def setUp(self) -> None: + random.seed(42) + + def do_test_enum(self, tags: range | Type[Enum] | list[int], tag_values: list[int]): + m = TaggedCounterCircuit(tags) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + counts: dict[int, int] = {} + for i in tag_values: + counts[i] = 0 + + def test_process(): + for _ in range(200): + for i in tag_values: + self.assertEqual(counts[i], (yield m.counter.counters[i].value)) + + tag = random.choice(list(tag_values)) + + yield m.cond.eq(1) + yield m.tag.eq(tag) + yield + yield m.cond.eq(0) + yield + + counts[tag] += 1 + + with 
class StableSelectingNetwork(Elaboratable):
    """A network that groups inputs with a valid bit set.

    The circuit takes `n` inputs with a valid signal each and
    on the output returns a grouped and consecutive sequence of the provided
    input signals. The order of valid inputs is preserved.

    For example for input (`-` is an invalid input):
    -, a, -, d, -, -, e

    The first `output_cnt` (here 3) outputs will be:
    a, d, e

    Only the first `output_cnt` outputs are meaningful. The inputs are not
    masked with their valid bits, so the remaining outputs are not
    guaranteed to be zero and must be ignored by the user.

    The circuit uses a divide and conquer algorithm.
    The recursive call takes two vectors and each of them
    is already properly grouped, for example:
    v1 = [a, b, -, -]; v2 = [c, d, e, -]

    Now by shifting v2 by the number of valid elements of v1 and merging
    it with v1, we get the result:
    v = [a, b, c, d, e, -, -, -]

    Neighbouring vectors are merged pairwise, level by level, so the
    number of levels grows logarithmically with `n`.

    Attributes
    ----------
    inputs: list[Signal]
        Input data, one signal with the given layout per input.
    valids: list[Signal]
        Valid bit for each input.
    outputs: list[Signal]
        Grouped data. Only the first `output_cnt` entries are meaningful.
    output_cnt: Signal
        Number of valid inputs.
    """

    def __init__(self, n: int, layout: MethodLayout):
        """
        Parameters
        ----------
        n: int
            Number of inputs (and outputs). Must be at least 1.
        layout: MethodLayout
            Layout of a single input/output.

        Raises
        ------
        ValueError
            If `n` is smaller than 1.
        """
        if n < 1:
            raise ValueError(f"StableSelectingNetwork needs at least one input, got n={n}")

        self.n = n
        self.layout = from_method_layout(layout)

        self.inputs = [Signal(self.layout) for _ in range(n)]
        self.valids = [Signal() for _ in range(n)]

        self.outputs = [Signal(self.layout) for _ in range(n)]
        self.output_cnt = Signal(range(n + 1))

    def elaborate(self, platform):
        m = TModule()

        # Every level is a list of (vector, number of valid elements) pairs.
        # On the lowest level each vector is a single input and its count is the valid bit.
        current_level = [(Array([self.inputs[i]]), self.valids[i]) for i in range(self.n)]

        # Create the network using the bottom-up approach.
        while len(current_level) >= 2:
            next_level = []

            # Merge neighbouring vectors pairwise, keeping their order.
            for pair_start in range(0, len(current_level) - 1, 2):
                a, cnt_a = current_level[pair_start]
                b, cnt_b = current_level[pair_start + 1]

                # One extra bit so that the sum of both counts always fits.
                total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1)
                m.d.comb += total_cnt.eq(cnt_a + cnt_b)

                total_len = len(a) + len(b)
                merged = Array(Signal(self.layout) for _ in range(total_len))

                # Lower part: the valid prefix of `a`, followed by `b` shifted by `cnt_a`.
                for i in range(len(a)):
                    m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i]))
                # Upper part: can only be reached by `b`; zero if the index falls outside of `b`.
                for i in range(len(b)):
                    m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a]))

                next_level.append((merged, total_cnt))

            # If we had an odd number of elements on the current level,
            # move the item left to the next level.
            if len(current_level) % 2 == 1:
                next_level.append(current_level[-1])

            current_level = next_level

        last_level, total_cnt = current_level[0]

        for i in range(self.n):
            m.d.comb += self.outputs[i].eq(last_level[i])

        m.d.comb += self.output_cnt.eq(total_cnt)

        return m
class TaggedCounter(Elaboratable, HwMetric):
    """Hardware Tagged Counter

    Like HwCounter, but contains multiple counters, each with its own tag.
    At a time a single counter can be increased and the value of the tag
    can be provided dynamically. The type of the tag can be either an int
    enum, a range or a list of integers (negative numbers are ok).

    Internally, it detects if tag values can be one-hot encoded and if so,
    it generates more optimized circuit.

    Attributes
    ----------
    tag_width: int
        The length of the signal holding a tag value.
    one_hot: bool
        Whether tag values can be one-hot encoded.
    counters: dict[int, HwMetricRegister]
        Mapping from a tag value to a register holding a counter for that tag.
    """

    def __init__(
        self,
        fully_qualified_name: str,
        description: str = "",
        *,
        tags: range | Type[Enum] | list[int],
        registers_width: int = 32,
    ):
        """
        Parameters
        ----------
        fully_qualified_name: str
            The fully qualified name of the metric.
        description: str
            A human-readable description of the metric's functionality.
        tags: range | Type[Enum] | list[int]
            Tag values. At least one tag is required.
        registers_width: int
            Width of the underlying registers. Defaults to 32 bits.

        Raises
        ------
        ValueError
            If `tags` is empty.
        """

        super().__init__(fully_qualified_name, description)

        if isinstance(tags, (range, list)):
            counters_meta = [(i, f"{i}") for i in tags]
        else:
            counters_meta = [(i.value, i.name) for i in tags]

        values = [value for value, _ in counters_meta]
        if not values:
            raise ValueError("TaggedCounter requires at least one tag value")

        negative_values = any(value < 0 for value in values)

        # When any tag is negative the tag is a signed value, so the positive
        # tags need room for the sign bit as well.
        self.tag_width = max(
            bits_for(max(values), require_sign_bit=negative_values),
            bits_for(min(values), require_sign_bit=negative_values),
        )

        # One-hot encoding is possible only if every tag is a power of two.
        self.one_hot = not negative_values and all(2 ** ceil_log2(value) == value for value in values)

        self._incr = Method(i=[("tag", Shape(self.tag_width, signed=negative_values))])

        self.counters: dict[int, HwMetricRegister] = {}
        for tag_value, name in counters_meta:
            value_str = ("1<<" + str(exact_log2(tag_value))) if self.one_hot else str(tag_value)
            counter_description = f"the counter for tag {name} (value={value_str})"

            self.counters[tag_value] = HwMetricRegister(
                name,
                registers_width,
                counter_description,
            )

        self.add_registers(list(self.counters.values()))

    def elaborate(self, platform):
        if not self.metrics_enabled():
            return TModule()

        m = TModule()

        @def_method(m, self._incr)
        def _(tag):
            if self.one_hot:
                for i in OneHotSwitchDynamic(m, tag):
                    # Bit `i` of the tag stands for the tag value `1 << i`.
                    # The tag set may skip some powers of two - such bits have no counter.
                    counter = self.counters.get(1 << i)
                    if counter is None:
                        continue
                    m.d.sync += counter.value.eq(counter.value + 1)
            else:
                for tag_value, counter in self.counters.items():
                    with m.If(tag == tag_value):
                        m.d.sync += counter.value.eq(counter.value + 1)

        return m

    def incr(self, m: TModule, tag: ValueLike, *, cond: ValueLike = C(1)):
        """
        Increases the counter of a given tag by 1.

        Should be called in the body of either a transaction or a method.

        Parameters
        ----------
        m: TModule
            Transactron module
        tag: ValueLike
            The tag of the counter.
        cond: ValueLike
            When set to high, the counter will be increased. By default set to high.
        """
        if not self.metrics_enabled():
            return

        with m.If(cond):
            self._incr(m, tag)
100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -60,8 +60,7 @@ def use(self, m: ModuleLike): class Unifier(Transformer, Protocol): method: Method - def __init__(self, targets: list[Method]): - ... + def __init__(self, targets: list[Method]): ... class MethodMap(Elaboratable, Transformer): diff --git a/transactron/utils/_typing.py b/transactron/utils/_typing.py index 32497c7d5..e8e3152b9 100644 --- a/transactron/utils/_typing.py +++ b/transactron/utils/_typing.py @@ -86,17 +86,13 @@ # Protocols for Amaranth classes class _ModuleBuilderDomainsLike(Protocol): - def __getattr__(self, name: str) -> "_ModuleBuilderDomain": - ... + def __getattr__(self, name: str) -> "_ModuleBuilderDomain": ... - def __getitem__(self, name: str) -> "_ModuleBuilderDomain": - ... + def __getitem__(self, name: str) -> "_ModuleBuilderDomain": ... - def __setattr__(self, name: str, value: "_ModuleBuilderDomain") -> None: - ... + def __setattr__(self, name: str, value: "_ModuleBuilderDomain") -> None: ... - def __setitem__(self, name: str, value: "_ModuleBuilderDomain") -> None: - ... + def __setitem__(self, name: str, value: "_ModuleBuilderDomain") -> None: ... _T_ModuleBuilderDomains = TypeVar("_T_ModuleBuilderDomains", bound=_ModuleBuilderDomainsLike) @@ -127,80 +123,59 @@ def Default(self) -> AbstractContextManager[None]: # noqa: N802 def FSM( # noqa: N802 self, reset: Optional[str] = ..., domain: str = ..., name: str = ... - ) -> AbstractContextManager["amaranth.hdl._dsl.FSM"]: - ... + ) -> AbstractContextManager["amaranth.hdl._dsl.FSM"]: ... def State(self, name: str) -> AbstractContextManager[None]: # noqa: N802 ... @property - def next(self) -> NoReturn: - ... + def next(self) -> NoReturn: ... @next.setter - def next(self, name: str) -> None: - ... + def next(self, name: str) -> None: ... class AbstractSignatureMembers(Protocol): - def flip(self) -> "AbstractSignatureMembers": - ... + def flip(self) -> "AbstractSignatureMembers": ... 
- def __eq__(self, other) -> bool: - ... + def __eq__(self, other) -> bool: ... - def __contains__(self, name: str) -> bool: - ... + def __contains__(self, name: str) -> bool: ... - def __getitem__(self, name: str) -> Member: - ... + def __getitem__(self, name: str) -> Member: ... - def __setitem__(self, name: str, member: Member) -> NoReturn: - ... + def __setitem__(self, name: str, member: Member) -> NoReturn: ... - def __delitem__(self, name: str) -> NoReturn: - ... + def __delitem__(self, name: str) -> NoReturn: ... - def __iter__(self) -> Iterator[str]: - ... + def __iter__(self) -> Iterator[str]: ... - def __len__(self) -> int: - ... + def __len__(self) -> int: ... - def flatten(self, *, path: tuple[str | int, ...] = ...) -> Iterator[tuple[tuple[str | int, ...], Member]]: - ... + def flatten(self, *, path: tuple[str | int, ...] = ...) -> Iterator[tuple[tuple[str | int, ...], Member]]: ... - def create(self, *, path: tuple[str | int, ...] = ..., src_loc_at: int = ...) -> dict[str, Any]: - ... + def create(self, *, path: tuple[str | int, ...] = ..., src_loc_at: int = ...) -> dict[str, Any]: ... - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... class AbstractSignature(Protocol): - def flip(self) -> "AbstractSignature": - ... + def flip(self) -> "AbstractSignature": ... @property - def members(self) -> AbstractSignatureMembers: - ... + def members(self) -> AbstractSignatureMembers: ... - def __eq__(self, other) -> bool: - ... + def __eq__(self, other) -> bool: ... - def flatten(self, obj) -> Iterator[tuple[tuple[str | int, ...], Flow, ValueLike]]: - ... + def flatten(self, obj) -> Iterator[tuple[tuple[str | int, ...], Flow, ValueLike]]: ... - def is_compliant(self, obj, *, reasons: Optional[list[str]] = ..., path: tuple[str, ...] = ...) -> bool: - ... + def is_compliant(self, obj, *, reasons: Optional[list[str]] = ..., path: tuple[str, ...] = ...) -> bool: ... def create( self, *, path: tuple[str | int, ...] = ..., src_loc_at: int = ... 
- ) -> "AbstractInterface[AbstractSignature]": - ... + ) -> "AbstractInterface[AbstractSignature]": ... - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... _T_AbstractSignature = TypeVar("_T_AbstractSignature", bound=AbstractSignature) @@ -211,14 +186,12 @@ class AbstractInterface(Protocol, Generic[_T_AbstractSignature]): class HasElaborate(Protocol): - def elaborate(self, platform) -> "HasElaborate": - ... + def elaborate(self, platform) -> "HasElaborate": ... @runtime_checkable class HasDebugSignals(Protocol): - def debug_signals(self) -> SignalBundle: - ... + def debug_signals(self) -> SignalBundle: ... def type_self_kwargs_as(as_func: Callable[Concatenate[Any, P], Any]): diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index 3af4ded98..b0ddbae35 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -59,13 +59,11 @@ def case(n: Optional[int] = None): @overload -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[True]) -> Iterable[Optional[int]]: - ... +def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[True]) -> Iterable[Optional[int]]: ... @overload -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[False] = False) -> Iterable[int]: - ... +def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[False] = False) -> Iterable[int]: ... def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: bool = False) -> Iterable[Optional[int]]: