diff --git a/coreblocks/backend/retirement.py b/coreblocks/backend/retirement.py index d68205c05..c5bd39d64 100644 --- a/coreblocks/backend/retirement.py +++ b/coreblocks/backend/retirement.py @@ -162,6 +162,7 @@ def flush_instr(rob_entry): # Register RISC-V architectural trap in CSRs m_csr.mcause.write(m, cause_entry) m_csr.mepc.write(m, cause_register.pc) + m_csr.mtval.write(m, cause_register.mtval) self.trap_entry(m) # Fetch is already stalled by ExceptionCauseRegister diff --git a/coreblocks/core.py b/coreblocks/core.py index 6aa9b5b8a..1a107622c 100644 --- a/coreblocks/core.py +++ b/coreblocks/core.py @@ -20,7 +20,7 @@ from coreblocks.core_structs.rf import RegisterFile from coreblocks.priv.csr.csr_instances import GenericCSRRegisters from coreblocks.frontend.frontend import CoreFrontend -from coreblocks.priv.traps.exception import ExceptionCauseRegister +from coreblocks.priv.traps.exception import ExceptionInformationRegister from coreblocks.scheduler.scheduler import Scheduler from coreblocks.backend.annoucement import ResultAnnouncement from coreblocks.backend.retirement import Retirement @@ -69,7 +69,7 @@ def __init__(self, *, gen_params: GenParams): self.connections.add_dependency(CommonBusDataKey(), self.bus_master_data_adapter) - self.exception_cause_register = ExceptionCauseRegister( + self.exception_information_register = ExceptionInformationRegister( self.gen_params, rob_get_indices=self.ROB.get_indices, fetch_stall_exception=self.frontend.stall, @@ -134,7 +134,7 @@ def elaborate(self, platform): gen_params=self.gen_params, ) - m.submodules.exception_cause_register = self.exception_cause_register + m.submodules.exception_information_register = self.exception_information_register fetch_resume_fb, fetch_resume_unifiers = self.connections.get_dependency(FetchResumeKey()) m.submodules.fetch_resume_unifiers = ModuleConnector(**fetch_resume_unifiers) @@ -151,8 +151,8 @@ def elaborate(self, platform): r_rat_peek=rrat.peek, free_rf_put=free_rf_fifo.write, rf_free=rf.free, - exception_cause_get=self.exception_cause_register.get, - exception_cause_clear=self.exception_cause_register.clear, + exception_cause_get=self.exception_information_register.get, + exception_cause_clear=self.exception_information_register.clear, frat_rename=frat.rename, fetch_continue=self.frontend.resume_from_exception, instr_decrement=core_counter.decrement, diff --git a/coreblocks/frontend/decoder/decode_stage.py b/coreblocks/frontend/decoder/decode_stage.py index 7064aa691..4d09e8de0 100644 --- a/coreblocks/frontend/decoder/decode_stage.py +++ b/coreblocks/frontend/decoder/decode_stage.py @@ -65,9 +65,9 @@ def elaborate(self, platform): ] exception_override = Signal() - m.d.comb += exception_override.eq(instr_decoder.illegal | raw.access_fault) + m.d.comb += exception_override.eq(instr_decoder.illegal | raw.access_fault.any()) exception_funct = Signal(Funct3) - with m.If(raw.access_fault): + with m.If(raw.access_fault.any()): m.d.comb += exception_funct.eq(Funct3._EINSTRACCESSFAULT) with m.Elif(instr_decoder.illegal): self.perf_illegal_instr.incr(m) @@ -95,7 +95,15 @@ def elaborate(self, platform): "rl_s1": Mux(instr_decoder.rs1_v & (~exception_override), instr_decoder.rs1, 0), "rl_s2": Mux(instr_decoder.rs2_v & (~exception_override), instr_decoder.rs2, 0), }, - "imm": instr_decoder.imm, + "imm": Mux( + ~exception_override, + instr_decoder.imm, + Mux( + raw.access_fault.any(), + raw.access_fault, # pass access fault details to FU + raw.instr, # illegal instruction - pass raw instruction bits for `mtval` + ), + ), "csr": instr_decoder.csr, "pc": raw.pc, }, diff --git a/coreblocks/frontend/decoder/instr_decoder.py b/coreblocks/frontend/decoder/instr_decoder.py index 69574fa70..35b83ccbe 100644 --- a/coreblocks/frontend/decoder/instr_decoder.py +++ b/coreblocks/frontend/decoder/instr_decoder.py @@ -317,6 +317,14 @@ def elaborate(self, platform): self.rs1_v.eq(0), ] + # HACK: pass logical registers in unused high bits of CSR instruction for `mtval` reconstruction + with m.If((self.optype == OpType.CSR_REG) | (self.optype == OpType.CSR_IMM)): + m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log : 32].eq(self.rd) + m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log].eq( + self.rs1 + ) + assert 32 - self.gen_params.isa.reg_cnt_log * 2 >= 5 + # Instruction simplification # lui rd, imm -> addi rd, x0, (imm << 12) diff --git a/coreblocks/frontend/decoder/rvc.py b/coreblocks/frontend/decoder/rvc.py index 00e244daa..c852b39e6 100644 --- a/coreblocks/frontend/decoder/rvc.py +++ b/coreblocks/frontend/decoder/rvc.py @@ -291,6 +291,8 @@ def elaborate(self, platform): res = self.instr_mux(quadrant, quadrants) - m.d.comb += self.instr_out.eq(Mux(res[1], res[0], IllegalInstr())) + # In case of illegal instruction, output `instr_in` to be able to save it into `mtval` CSR. + # Decoder would still recognize it as illegal because of quadrant != 0b11 + m.d.comb += self.instr_out.eq(Mux(res[1], res[0], self.instr_in)) return m diff --git a/coreblocks/frontend/fetch/fetch.py b/coreblocks/frontend/fetch/fetch.py index 21122199d..efe1b39d0 100644 --- a/coreblocks/frontend/fetch/fetch.py +++ b/coreblocks/frontend/fetch/fetch.py @@ -323,19 +323,28 @@ def flush(): m.d.av_comb += [ raw_instrs[i].instr.eq(instrs[i]), raw_instrs[i].pc.eq(params.pc_from_fb(fetch_block_addr, i)), - raw_instrs[i].access_fault.eq(access_fault), raw_instrs[i].rvc.eq(s1_data.rvc[i]), raw_instrs[i].predicted_taken.eq(redirect & (predcheck_res.fb_instr_idx == i)), + raw_instrs[i].access_fault.eq( + Mux(s1_data.access_fault, FetchLayouts.AccessFaultFlag.ACCESS_FAULT, 0) + ), ] if Extension.C in self.gen_params.isa.extensions: with m.If(s1_data.instr_block_cross): m.d.av_comb += raw_instrs[0].pc.eq(params.pc_from_fb(fetch_block_addr, 0) - 2) + with m.If(s1_data.access_fault): + # Mark that access fault happened only at second (current) half. + # If fault happened on the first half `instr_block_cross` would be false + m.d.av_comb += raw_instrs[0].access_fault.eq( + FetchLayouts.AccessFaultFlag.ACCESS_FAULT_ON_SECOND_HALF + ) with condition(m) as branch: with branch(flushing_counter == 0): with m.If(access_fault | unsafe_stall): # TODO: Raise different code for page fault when supported + # could be passed in 3rd bit of access_fault flush() m.d.sync += stalled_unsafe.eq(1) with m.Elif(redirect): @@ -523,6 +532,7 @@ def elaborate(self, platform): @def_method(m, self.predecode) def _(instr): + quadrant = instr[0:2] opcode = instr[2:7] funct3 = instr[12:15] rd = instr[7:12] @@ -557,6 +567,9 @@ def _(instr): with m.Default(): m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID) + with m.If(quadrant != 0b11): + m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID) + m.d.av_comb += ret.unsafe.eq( (opcode == Opcode.SYSTEM) | ((opcode == Opcode.MISC_MEM) & (funct3 == Funct3.FENCEI)) ) diff --git a/coreblocks/func_blocks/csr/csr.py b/coreblocks/func_blocks/csr/csr.py index a82074256..b59ae18f8 100644 --- a/coreblocks/func_blocks/csr/csr.py +++ b/coreblocks/func_blocks/csr/csr.py @@ -8,6 +8,7 @@ from transactron.utils.dependencies import DependencyContext from coreblocks.arch import OpType, Funct3, ExceptionCause, PrivilegeLevel +from coreblocks.arch.isa_consts import Opcode from coreblocks.params import GenParams from coreblocks.params.fu_params import BlockComponentParams from coreblocks.func_blocks.interface.func_protocols import FuncBlock @@ -189,7 +190,7 @@ def _(rs_entry_id, rs_data): m.d.sync += assign(instr, rs_data) with m.If(rs_data.exec_fn.op_type == OpType.CSR_IMM): # Pass immediate as first operand - m.d.sync += instr.s1_val.eq(rs_data.imm) + m.d.sync += instr.s1_val.eq(rs_data.imm[0:5]) m.d.sync += instr.valid.eq(1) @@ -209,7 +210,21 @@ def _(): interrupt = self.dependency_manager.get_dependency(AsyncInterruptInsertSignalKey()) with m.If(exception): - report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc) + mtval = Signal(self.gen_params.isa.xlen) + # re-encode the CSR instruction to speed-up missing CSR emulation (optional, otherwise mtval must be 0) + m.d.av_comb += mtval[0:2].eq(0b11) + m.d.av_comb += mtval[2:7].eq(Opcode.SYSTEM) + m.d.av_comb += mtval[7:12].eq(instr.imm[32 - self.gen_params.isa.reg_cnt_log : 32]) # rl_rd + m.d.av_comb += mtval[12:15].eq(instr.exec_fn.funct3) + m.d.av_comb += mtval[15:20].eq( + Mux( + instr.exec_fn.op_type == OpType.CSR_IMM, + instr.imm[0:5], + instr.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log], + ) + ) # rl_s1 or imm + m.d.av_comb += mtval[20:32].eq(instr.csr) + report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc, mtval=mtval) with m.Elif(interrupt): # SPEC: "These conditions for an interrupt trap to occur [..] must also be evaluated immediately # following [..] an explicit write to a CSR on which these interrupt trap conditions expressly depend." @@ -221,6 +236,7 @@ def _(): rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=instr.pc + self.gen_params.isa.ilen_bytes, + mtval=0, ) m.d.sync += exception.eq(0) diff --git a/coreblocks/func_blocks/fu/exception.py b/coreblocks/func_blocks/fu/exception.py index e4b67e6a8..4d5ba2ff7 100644 --- a/coreblocks/func_blocks/fu/exception.py +++ b/coreblocks/func_blocks/fu/exception.py @@ -8,7 +8,7 @@ from coreblocks.params import GenParams, FunctionalComponentParams from coreblocks.arch import OpType, Funct3, ExceptionCause -from coreblocks.interface.layouts import FuncUnitLayouts +from coreblocks.interface.layouts import FetchLayouts, FuncUnitLayouts from transactron.utils import OneHotSwitch from coreblocks.interface.keys import ExceptionReportKey, CSRInstancesKey @@ -69,28 +69,39 @@ def _(arg): m.d.comb += decoder.exec_fn.eq(arg.exec_fn) cause = Signal(ExceptionCause) + mtval = Signal(self.gen_params.isa.xlen) priv_level = self.dm.get_dependency(CSRInstancesKey()).m_mode.priv_mode.read(m).data with OneHotSwitch(m, decoder.decode_fn) as OneHotCase: with OneHotCase(ExceptionUnitFn.Fn.EBREAK): - m.d.comb += cause.eq(ExceptionCause.BREAKPOINT) + m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT) + m.d.av_comb += mtval.eq(arg.pc) with OneHotCase(ExceptionUnitFn.Fn.ECALL): with m.Switch(priv_level): with m.Case(PrivilegeLevel.MACHINE): - m.d.comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M) + m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M) with m.Case(PrivilegeLevel.USER): - m.d.comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_U) + m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_U) + m.d.av_comb += mtval.eq(0) # by SPEC with OneHotCase(ExceptionUnitFn.Fn.INSTR_ACCESS_FAULT): - m.d.comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT) + m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT) + # With C extension access fault can be only on the second half of instruction, and mepc != mtval. + # This information is passed in imm field + m.d.av_comb += mtval.eq( + arg.pc + ((arg.imm & FetchLayouts.AccessFaultFlag.ACCESS_FAULT_ON_SECOND_HALF).any() << 1) + ) with OneHotCase(ExceptionUnitFn.Fn.ILLEGAL_INSTRUCTION): - m.d.comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION) + m.d.av_comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION) + m.d.av_comb += mtval.eq(arg.imm) # passed instruction bytes with OneHotCase(ExceptionUnitFn.Fn.BREAKPOINT): - m.d.comb += cause.eq(ExceptionCause.BREAKPOINT) + m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT) + m.d.av_comb += mtval.eq(arg.pc) with OneHotCase(ExceptionUnitFn.Fn.INSTR_PAGE_FAULT): - m.d.comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT) + m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT) + m.d.av_comb += mtval.eq(arg.pc + (arg.imm[1] << 1)) - self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc) + self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc, mtval=mtval) fifo.write(m, result=0, exception=1, rob_id=arg.rob_id, rp_dst=arg.rp_dst) diff --git a/coreblocks/func_blocks/fu/jumpbranch.py b/coreblocks/func_blocks/fu/jumpbranch.py index 2d164d05a..6cc31f66f 100644 --- a/coreblocks/func_blocks/fu/jumpbranch.py +++ b/coreblocks/func_blocks/fu/jumpbranch.py @@ -214,8 +214,13 @@ def _(): # generated for a conditional branch that is not taken." m.d.comb += exception.eq(1) exception_report( - m, rob_id=instr.rob_id, cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED, pc=instr.pc + m, + rob_id=instr.rob_id, + cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED, + pc=instr.pc, + mtval=instr.jmp_addr, ) + with m.Elif(async_interrupt_active & ~is_auipc): # Jump instructions are entry points for async interrupts. # This way we can store known pc via report to global exception register and avoid it in ROB. @@ -223,13 +228,15 @@ def _(): # and exception would be lost. m.d.comb += exception.eq(1) exception_report( - m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result + m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result, mtval=0 ) with m.Elif(misprediction): # Async interrupts can have priority, because `jump_result` is handled in the same way. # No extra misprediction penalty will be introducted at interrupt return to `jump_result` address. m.d.comb += exception.eq(1) - exception_report(m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result) + exception_report( + m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result, mtval=0 + ) with m.If(~is_auipc): self.fifo_branch_resolved.write(m, from_pc=instr.pc, next_pc=jump_result, misprediction=misprediction) diff --git a/coreblocks/func_blocks/fu/lsu/dummyLsu.py b/coreblocks/func_blocks/fu/lsu/dummyLsu.py index e807cc636..cef1daa4e 100644 --- a/coreblocks/func_blocks/fu/lsu/dummyLsu.py +++ b/coreblocks/func_blocks/fu/lsu/dummyLsu.py @@ -77,7 +77,7 @@ def _(arg): with m.If(~is_fence): requests.write(m, arg) with m.Else(): - results_noop.write(m, data=0, exception=0, cause=0) + results_noop.write(m, data=0, exception=0, cause=0, addr=0) issued_noop.write(m, arg) # Issues load/store requests when the instruction is known, is a LOAD/STORE, and just before commit. @@ -106,14 +106,14 @@ def _(arg): with m.If(res["exception"]): issued_noop.write(m, arg) - results_noop.write(m, data=0, exception=res["exception"], cause=res["cause"]) + results_noop.write(m, data=0, exception=res["exception"], cause=res["cause"], addr=addr) with m.Else(): issued.write(m, arg) # Handles flushed instructions as a no-op. with Transaction().body(m, request=flush): arg = requests.read(m) - results_noop.write(m, data=0, exception=0, cause=0) + results_noop.write(m, data=0, exception=0, cause=0, addr=0) issued_noop.write(m, arg) @def_method(m, self.accept) @@ -129,7 +129,7 @@ def _(): m.d.comb += arg.eq(issued_noop.read(m)) with m.If(res["exception"]): - self.report(m, rob_id=arg["rob_id"], cause=res["cause"], pc=arg["pc"]) + self.report(m, rob_id=arg["rob_id"], cause=res["cause"], pc=arg["pc"], mtval=res["addr"]) self.log.debug(m, 1, "accept rob_id={} result=0x{:08x} exception={}", arg.rob_id, res.data, res.exception) diff --git a/coreblocks/func_blocks/fu/lsu/lsu_requester.py b/coreblocks/func_blocks/fu/lsu/lsu_requester.py index 9176294ee..c8abd6017 100644 --- a/coreblocks/func_blocks/fu/lsu/lsu_requester.py +++ b/coreblocks/func_blocks/fu/lsu/lsu_requester.py @@ -172,6 +172,6 @@ def _(): Mux(request_args.store, ExceptionCause.STORE_ACCESS_FAULT, ExceptionCause.LOAD_ACCESS_FAULT) ) - return {"data": data, "exception": exception, "cause": cause} + return {"data": data, "exception": exception, "cause": cause, "addr": request_args.addr} return m diff --git a/coreblocks/func_blocks/fu/priv.py b/coreblocks/func_blocks/fu/priv.py index 5c90bbc5f..268c50d14 100644 --- a/coreblocks/func_blocks/fu/priv.py +++ b/coreblocks/func_blocks/fu/priv.py @@ -2,7 +2,7 @@ from enum import IntFlag, auto, unique from typing import Sequence -from coreblocks.arch.isa_consts import PrivilegeLevel +from coreblocks.arch.isa_consts import Funct12, Funct3, Opcode, PrivilegeLevel from transactron import * @@ -45,17 +45,17 @@ def get_instructions(cls) -> Sequence[tuple]: class PrivilegedFuncUnit(Elaboratable): - def __init__(self, gp: GenParams): - self.gp = gp + def __init__(self, gen_params: GenParams): + self.gen_params = gen_params self.priv_fn = PrivilegedFn() - self.layouts = layouts = gp.get(FuncUnitLayouts) + self.layouts = layouts = gen_params.get(FuncUnitLayouts) self.dm = DependencyContext.get() self.issue = Method(i=layouts.issue) self.accept = Method(o=layouts.accept) - self.fetch_resume_fifo = BasicFifo(self.gp.get(FetchLayouts).resume, 2) + self.fetch_resume_fifo = BasicFifo(self.gen_params.get(FetchLayouts).resume, 2) self.perf_instr = TaggedCounter( "backend.fu.priv.instr", @@ -68,14 +68,14 @@ def elaborate(self, platform): m.submodules += [self.perf_instr] - m.submodules.decoder = decoder = self.priv_fn.get_decoder(self.gp) + m.submodules.decoder = decoder = self.priv_fn.get_decoder(self.gen_params) instr_valid = Signal() finished = Signal() illegal_instruction = Signal() - instr_rob = Signal(self.gp.rob_entries_bits) - instr_pc = Signal(self.gp.isa.xlen) + instr_rob = Signal(self.gen_params.rob_entries_bits) + instr_pc = Signal(self.gen_params.isa.xlen) instr_fn = self.priv_fn.get_function() mret = self.dm.get_dependency(MretKey()) @@ -129,7 +129,7 @@ def _(): m.d.sync += instr_valid.eq(0) m.d.sync += finished.eq(0) - ret_pc = Signal(self.gp.isa.xlen) + ret_pc = Signal(self.gen_params.isa.xlen) with OneHotSwitch(m, instr_fn) as OneHotCase: with OneHotCase(PrivilegedFn.Fn.MRET): @@ -145,8 +145,21 @@ def _(): exception = Signal() with m.If(illegal_instruction): - m.d.comb += exception.eq(1) - exception_report(m, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=ret_pc, rob_id=instr_rob) + m.d.av_comb += exception.eq(1) + + # Replace with const zero if turns out not worth to re-encode instruction + instr = Signal(self.gen_params.isa.xlen) + m.d.av_comb += instr[0:2].eq(0b11) + m.d.av_comb += instr[2:7].eq(Opcode.SYSTEM) + m.d.av_comb += instr[7:12].eq(0) + m.d.av_comb += instr[12:15].eq(Funct3.PRIV) + m.d.av_comb += instr[15:20].eq(0) + m.d.av_comb += instr[20:32].eq(Mux(instr_fn == PrivilegedFn.Fn.MRET, Funct12.WFI, Funct12.MRET)) + log.error( + m, (instr_fn != PrivilegedFn.Fn.MRET) & (instr_fn != PrivilegedFn.Fn.WFI), "missing Funct12 case" + ) + + exception_report(m, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=ret_pc, rob_id=instr_rob, mtval=instr) with m.Elif(async_interrupt_active): # SPEC: "These conditions for an interrupt trap to occur [..] must also be evaluated immediately # following the execution of an xRET instruction." @@ -155,8 +168,10 @@ def _(): # by updated async_interrupt_active signal. # Interrupt is reported on this xRET instruction with return address set to instruction that we # would normally return to (mepc value is preserved) - m.d.comb += exception.eq(1) - exception_report(m, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=ret_pc, rob_id=instr_rob) + m.d.av_comb += exception.eq(1) + exception_report( + m, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=ret_pc, rob_id=instr_rob, mtval=0 + ) with m.Else(): log.info(m, True, "Unstalling fetch from the priv unit new_pc=0x{:x}", ret_pc) # Unstall the fetch diff --git a/coreblocks/interface/layouts.py b/coreblocks/interface/layouts.py index 31f42a2e2..a19ecca45 100644 --- a/coreblocks/interface/layouts.py +++ b/coreblocks/interface/layouts.py @@ -1,10 +1,11 @@ from typing import Optional from amaranth import signed -from amaranth.lib.data import StructLayout, ArrayLayout +from amaranth.lib.data import ArrayLayout +from amaranth.lib.enum import IntFlag, auto from coreblocks.params import GenParams from coreblocks.arch import * from transactron.utils import LayoutList, LayoutListField, layout_subset -from transactron.utils.transactron_helpers import from_method_layout, make_layout +from transactron.utils.transactron_helpers import from_method_layout, make_layout, extend_layout __all__ = [ "CommonLayoutFields", @@ -441,11 +442,20 @@ def __init__(self, gen_params: GenParams): class FetchLayouts: """Layouts used in the fetcher.""" + class AccessFaultFlag(IntFlag): + # standard access fault when accessing instruction + # from beginning (exception pc = instruction pc) (fault on full instruction or first half) + ACCESS_FAULT = auto() + # with C extension (2-byte alignment enabled) fault condition + # could only affect second half of 4-byte instruction. + # Bit set if this is the case + ACCESS_FAULT_ON_SECOND_HALF = auto() + def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.access_fault: LayoutListField = ("access_fault", 1) - """Instruction fetch failed.""" + self.access_fault: LayoutListField = ("access_fault", FetchLayouts.AccessFaultFlag) + """Instruction fetch errors. See `FetchLayouts.AccessFaultFlag` fields documentation""" self.raw_instr = make_layout( fields.instr, @@ -581,7 +591,7 @@ def __init__(self, gen_params: GenParams): self.issue_out = make_layout(fields.exception, fields.cause) - self.accept = make_layout(fields.data, fields.exception, fields.cause) + self.accept = make_layout(fields.data, fields.exception, fields.cause, fields.addr) class CSRRegisterLayouts: @@ -634,20 +644,24 @@ def __init__(self, gen_params: GenParams): class ExceptionRegisterLayouts: - """Layouts used in the exception register.""" + """Layouts used in the exception information register.""" def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) + self.mtval: LayoutListField = ("mtval", gen_params.isa.xlen) + """ Value to set for mtval CSR register """ + self.valid: LayoutListField = ("valid", 1) self.report = make_layout( fields.cause, fields.rob_id, fields.pc, + self.mtval, ) - self.get = StructLayout(self.report.members | make_layout(self.valid).members) + self.get = extend_layout(self.report, self.valid) class InternalInterruptControllerLayouts: diff --git a/coreblocks/priv/csr/csr_instances.py b/coreblocks/priv/csr/csr_instances.py index ec254739a..3215b42ee 100644 --- a/coreblocks/priv/csr/csr_instances.py +++ b/coreblocks/priv/csr/csr_instances.py @@ -82,6 +82,8 @@ def __init__(self, gen_params: GenParams): mepc_ro_bits = 0b1 if Extension.C in gen_params.isa.extensions else 0b11 # pc alignment (SPEC) self.mepc = CSRRegister(CSRAddress.MEPC, gen_params, ro_bits=mepc_ro_bits) + self.mtval = CSRRegister(CSRAddress.MTVAL, gen_params) + self.priv_mode = CSRRegister( None, gen_params, diff --git a/coreblocks/priv/traps/exception.py b/coreblocks/priv/traps/exception.py index d2f34f43e..9898626e1 100644 --- a/coreblocks/priv/traps/exception.py +++ b/coreblocks/priv/traps/exception.py @@ -40,8 +40,8 @@ def should_update_prioriy(m: TModule, current_cause: Value, new_cause: Value) -> return _update -class ExceptionCauseRegister(Elaboratable): - """ExceptionCauseRegister +class ExceptionInformationRegister(Elaboratable): + """ExceptionInformationRegister Stores parameters of earliest (in instruction order) exception, to save resources in the `ReorderBuffer`. All FUs that report exceptions should `report` the details to `ExceptionCauseRegister` and set `exception` bit in @@ -56,6 +56,7 @@ def __init__(self, gen_params: GenParams, rob_get_indices: Method, fetch_stall_e self.cause = Signal(ExceptionCause) self.rob_id = Signal(gen_params.rob_entries_bits) self.pc = Signal(gen_params.isa.xlen) + self.mtval = Signal(gen_params.isa.xlen) self.valid = Signal() self.layouts = gen_params.get(ExceptionRegisterLayouts) @@ -82,7 +83,7 @@ def elaborate(self, platform): m.submodules.report_connector = ConnectTrans(self.fu_report_fifo.read, report) @def_method(m, report) - def _(cause, rob_id, pc): + def _(cause, rob_id, pc, mtval): should_write = Signal() with m.If(self.valid & (self.rob_id == rob_id)): @@ -101,6 +102,7 @@ def _(cause, rob_id, pc): m.d.sync += self.rob_id.eq(rob_id) m.d.sync += self.cause.eq(cause) m.d.sync += self.pc.eq(pc) + m.d.sync += self.mtval.eq(mtval) m.d.sync += self.valid.eq(1) @@ -109,7 +111,7 @@ def _(cause, rob_id, pc): @def_method(m, self.get) def _(): - return {"rob_id": self.rob_id, "cause": self.cause, "pc": self.pc, "valid": self.valid} + return {"rob_id": self.rob_id, "cause": self.cause, "pc": self.pc, "mtval": self.mtval, "valid": self.valid} @def_method(m, self.clear) def _(): diff --git a/test/asm/mtval.asm b/test/asm/mtval.asm new file mode 100644 index 000000000..d56fffe0b --- /dev/null +++ b/test/asm/mtval.asm @@ -0,0 +1,105 @@ +# test `mtval` and `mcause` CSR values for various excpetions +# C extension is required in the core, but must be disabled in the toolchain + + la x1, handler + csrw mtvec, x1 + li x8, 0 + + li x7, 0x80000000 +c0: # load from illegal address. mtval=addr mcause=LOAD_ACCESS_FAULT + lw x1, 0x230(x7) +c1: # mtval=pc mcause=BREAKPOINT + ebreak +c2: # instruction address out of memory mtval=i_out_of_range mcause=INSTRUCTION_ACCESS_FAULT + j i_out_of_range +c3: # jump to 2-byte aligned, 4-byte long instruction, of which first two bytes are available + # and other half is outside of memory range. mtval=i_partial_out_of_range+2 mcause=INSTRUCTION_ACCESS_FAULT + j i_partial_out_of_range +c4: # illegal 4-byte instruction ([0:2] = 0b11) mtval=raw instruction mcause=ILLEGAL_INSTRUCTION +.word 0x43 +c5: # illegal compressed type ([0:2] != 0b11) instruction mtval=raw instruction mcause=ILLEGAL_INSTRUCTION +.word 0x8000 +c6: # access to missing csr mtval=raw instruction mcause=ILLEGAL_INSTRUCTION + csrr x1, 0x123 +c7: # access to missing csr mtval=raw instruction mcause=ILLEGAL_INSTRUCTION + csrwi 0x123, 8 +c8: # store to misaligned address mtvak=addr mcause=STORE_ADDRESS_MISALIGNED + sw x1, 0x231(x7) +c9: # mtval=0 mcause=ENVIRONMENT_CALL_FROM_M + ecall + +pass: + j pass + + +handler: # test each case. test case number = in x8>>2 + la x1, excpected_mtval + add x1, x1, x8 + lw x2, (x1) + csrr x1, mtval + bne x1, x2, fail + + la x1, excpected_mcause + add x1, x1, x8 + lw x2, (x1) + csrr x1, mcause + bne x1, x2, fail + + la x1, next_instr + add x1, x1, x8 + lw x2, (x1) + csrw mepc, x2 + + addi x8, x8, 4 + + mret + +fail: + j fail + +# it is legal - C is enabled in core, but can't be enabled in toolchain to keep 4-byte nops +.org 0x0FFE +i_partial_out_of_range: +nop +i_out_of_range: +nop + +.data + +excpected_mtval: +.word 0x80000230 +.word c1 +.word i_out_of_range +.word i_partial_out_of_range + 2 +.word 0x43 +.word 0x8000 +.word 0x123020f3 +.word 0x12345073 +.word 0x80000231 +.word 0 + +excpected_mcause: +.word 5 +.word 3 +.word 1 +.word 1 +.word 2 +.word 2 +.word 2 +.word 2 +.word 6 +.word 11 +# testing misaligned instr branch is not possible with C enabled :( + +next_instr: +.word c1 +.word c2 +.word c3 +.word c4 +.word c5 +.word c6 +.word c7 +.word c8 +.word c9 +.word pass + diff --git a/test/frontend/test_instr_decoder.py b/test/frontend/test_instr_decoder.py index 2abd84c9d..1830f1063 100644 --- a/test/frontend/test_instr_decoder.py +++ b/test/frontend/test_instr_decoder.py @@ -217,7 +217,11 @@ def process(): assert (yield self.decoder.rs2_v) == (test.rs2 is not None) if test.imm is not None: - assert (yield self.decoder.imm.as_signed()) == test.imm + if test.csr is not None: + # in CSR instruction additional fields are passed in unused bits of imm field + assert (yield self.decoder.imm.as_signed() & ((2**5) - 1)) == test.imm + else: + assert (yield self.decoder.imm.as_signed()) == test.imm if test.succ is not None: assert (yield self.decoder.succ) == test.succ diff --git a/test/frontend/test_rvc.py b/test/frontend/test_rvc.py index fb62b4ed2..f1690f8dd 100644 --- a/test/frontend/test_rvc.py +++ b/test/frontend/test_rvc.py @@ -284,12 +284,19 @@ def test(self): self.m = InstrDecompress(self.gen_params) def process(): + illegal = Signal(32) + yield illegal.eq(IllegalInstr()) + for instr_in, instr_out in self.test_cases: yield self.m.instr_in.eq(instr_in) expected = Signal(32) yield expected.eq(instr_out) yield Settle() + if (yield expected) == (yield illegal): + yield expected.eq(instr_in) # for exception handling + yield Settle() + assert (yield self.m.instr_out) == (yield expected) yield Tick() diff --git a/test/func_blocks/csr/test_csr.py b/test/func_blocks/csr/test_csr.py index 001c031e3..340afeda9 100644 --- a/test/func_blocks/csr/test_csr.py +++ b/test/func_blocks/csr/test_csr.py @@ -103,7 +103,7 @@ def generate_instruction(self): rd = random.randint(0, 15) rs1 = 0 if imm_op else random.randint(0, 15) - imm = random.randint(0, 2**self.gen_params.isa.xlen - 1) + imm = random.randint(0, 2**5 - 1) rs1_val = random.randint(0, 2**self.gen_params.isa.xlen - 1) if rs1 else 0 operand_val = imm if imm_op else rs1_val csr = random.choice(list(self.dut.csr.keys())) @@ -207,7 +207,8 @@ def process_exception_test(self): assert res["exception"] == 1 report = yield from self.dut.exception_report.call_result() - assert report is not None + assert isinstance(report, dict) + report.pop("mtval") # mtval tested in mtval.asm test assert {"rob_id": rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report def test_exception(self): diff --git a/test/func_blocks/fu/functional_common.py b/test/func_blocks/fu/functional_common.py index 9426cad70..088c4337d 100644 --- a/test/func_blocks/fu/functional_common.py +++ b/test/func_blocks/fu/functional_common.py @@ -143,10 +143,18 @@ def setup(self, fixture_initialize_testing_env): cause = None if "exception" in results: cause = results["exception"] - self.exceptions.append({"rob_id": rob_id, "cause": cause, "pc": results.setdefault("exception_pc", pc)}) + self.exceptions.append( + { + "rob_id": rob_id, + "cause": cause, + "pc": results.setdefault("exception_pc", pc), + "mtval": results.setdefault("mtval", 0), + } + ) results.pop("exception") results.pop("exception_pc") + results.pop("mtval") self.responses.append({"rob_id": rob_id, "rp_dst": rp_dst, "exception": int(cause is not None)} | results) diff --git a/test/func_blocks/fu/test_exception_unit.py b/test/func_blocks/fu/test_exception_unit.py index 3f6793c4b..ad57bb896 100644 --- a/test/func_blocks/fu/test_exception_unit.py +++ b/test/func_blocks/fu/test_exception_unit.py @@ -22,20 +22,25 @@ class TestExceptionUnit(FunctionalUnitTestCase[ExceptionUnitFn.Fn]): @staticmethod def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: ExceptionUnitFn.Fn, xlen: int) -> dict[str, int]: cause = None + mtval = 0 match fn: case ExceptionUnitFn.Fn.EBREAK | ExceptionUnitFn.Fn.BREAKPOINT: cause = ExceptionCause.BREAKPOINT + mtval = pc case ExceptionUnitFn.Fn.ECALL: cause = ExceptionCause.ENVIRONMENT_CALL_FROM_M case ExceptionUnitFn.Fn.INSTR_ACCESS_FAULT: cause = ExceptionCause.INSTRUCTION_ACCESS_FAULT + mtval = pc case ExceptionUnitFn.Fn.INSTR_PAGE_FAULT: cause = ExceptionCause.INSTRUCTION_PAGE_FAULT + mtval = pc case ExceptionUnitFn.Fn.ILLEGAL_INSTRUCTION: cause = ExceptionCause.ILLEGAL_INSTRUCTION + mtval = i_imm # in case of illegal instruction, raw instr bits are passed in imm field - return {"result": 0} | {"exception": cause} if cause is not None else {} + return {"result": 0} | {"exception": cause, "mtval": mtval} if cause is not None else {} def test_fu(self): self.run_standard_fu_test() diff --git a/test/func_blocks/fu/test_jb_unit.py b/test/func_blocks/fu/test_jb_unit.py index baa2759af..84ab5603d 100644 --- a/test/func_blocks/fu/test_jb_unit.py +++ b/test/func_blocks/fu/test_jb_unit.py @@ -120,14 +120,16 @@ def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: JumpBranchFn.Fn, x exception = None exception_pc = pc + mtval = 0 if next_pc & 0b11 != 0: exception = ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED + mtval = next_pc elif misprediction: exception = ExceptionCause._COREBLOCKS_MISPREDICTION exception_pc = next_pc return {"result": res, "from_pc": pc, "next_pc": next_pc, "misprediction": misprediction} | ( - {"exception": exception, "exception_pc": exception_pc} if exception is not None else {} + {"exception": exception, "exception_pc": exception_pc, "mtval": mtval} if exception is not None else {} ) diff --git a/test/func_blocks/lsu/test_dummylsu.py b/test/func_blocks/lsu/test_dummylsu.py index 1d92a50c3..976550a69 100644 --- a/test/func_blocks/lsu/test_dummylsu.py +++ b/test/func_blocks/lsu/test_dummylsu.py @@ -168,6 +168,7 @@ def generate_instr(self, max_reg_val, max_imm_val): ExceptionCause.LOAD_ADDRESS_MISALIGNED if misaligned else ExceptionCause.LOAD_ACCESS_FAULT ), "pc": 0, + "mtval": addr, } ) diff --git a/test/priv/traps/test_exception.py b/test/priv/traps/test_exception.py index bdf342327..22ebb8b5e 100644 --- a/test/priv/traps/test_exception.py +++ b/test/priv/traps/test_exception.py @@ -1,7 +1,7 @@ from amaranth import * from coreblocks.interface.layouts import ROBLayouts -from coreblocks.priv.traps.exception import ExceptionCauseRegister +from coreblocks.priv.traps.exception import ExceptionInformationRegister from coreblocks.params import GenParams from coreblocks.arch import ExceptionCause from coreblocks.params.configurations import test_core_config @@ -13,7 +13,7 @@ import random -class TestExceptionCauseRegister(TestCaseWithSimulator): +class TestExceptionInformationRegister(TestCaseWithSimulator): rob_max = 7 def should_update(self, new_arg, old_arg, rob_start) -> bool: @@ -36,7 +36,7 @@ def test_randomized(self): self.rob_idx_mock = TestbenchIO(Adapter(o=self.gen_params.get(ROBLayouts).get_indices)) self.fetch_stall_mock = TestbenchIO(Adapter()) self.dut = SimpleTestCircuit( - ExceptionCauseRegister( + ExceptionInformationRegister( self.gen_params, self.rob_idx_mock.adapter.iface, self.fetch_stall_mock.adapter.iface ) ) @@ -57,7 +57,8 @@ def process_test(): while saved_entry and report_rob == saved_entry["rob_id"]: report_rob = random.randint(0, self.rob_max) report_pc = random.randrange(2**self.gen_params.isa.xlen) - report_arg = {"cause": cause, "rob_id": report_rob, "pc": report_pc} + report_mtval = random.randrange(2**self.gen_params.isa.xlen) + report_arg = {"cause": cause, "rob_id": report_rob, "pc": report_pc, "mtval": report_mtval} expected = report_arg if self.should_update(report_arg, saved_entry, self.rob_id) else saved_entry yield from self.dut.report.call(report_arg) diff --git a/test/test_core.py b/test/test_core.py index eae0bcc85..1214d25e6 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -137,9 +137,11 @@ def load_section(section: str): ("exception_mem", "exception_mem.asm", 200, {1: 1, 2: 2}, basic_core_config), ("exception_handler", "exception_handler.asm", 2000, {2: 987, 11: 0xAAAA, 15: 16}, full_core_config), ("wfi_no_int", "wfi_no_int.asm", 200, {1: 1}, full_core_config), + ("mtval", "mtval.asm", 2000, {8: 5 * 8}, full_core_config), ], ) class TestCoreBasicAsm(TestCoreAsmSourceBase): + name: str source_file: str cycle_count: int expected_regvals: dict[int, int] @@ -156,6 +158,10 @@ def test_asm_source(self): self.gen_params = GenParams(self.configuration) bin_src = self.prepare_source(self.source_file) + + if self.name == "mtval": + bin_src["text"] = bin_src["text"][: 0x1000 // 4] # force instruction memory size clip in `mtval` test + self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src["text"], data_mem=bin_src["data"]) with self.run_simulation(self.m) as sim: diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py index 1d411788a..1fae8827a 100644 --- a/transactron/utils/transactron_helpers.py +++ b/transactron/utils/transactron_helpers.py @@ -19,6 +19,8 @@ "mock_def_helper", "get_src_loc", "from_method_layout", + "make_layout", + "extend_layout", ] T = TypeVar("T") @@ -147,6 +149,10 @@ def make_layout(*fields: LayoutListField) -> StructLayout: return from_method_layout(fields) +def extend_layout(layout: StructLayout, *fields: LayoutListField) -> StructLayout: + return StructLayout(layout.members | from_method_layout(fields).members) + + def from_method_layout(layout: MethodLayout) -> StructLayout: if isinstance(layout, StructLayout): return layout