Skip to content

Commit

Permalink
Implement full mtval (#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotro888 authored Oct 15, 2024
1 parent 0e2887a commit 8f2072f
Show file tree
Hide file tree
Showing 26 changed files with 308 additions and 63 deletions.
1 change: 1 addition & 0 deletions coreblocks/backend/retirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def flush_instr(rob_entry):
# Register RISC-V architectural trap in CSRs
m_csr.mcause.write(m, cause_entry)
m_csr.mepc.write(m, cause_register.pc)
m_csr.mtval.write(m, cause_register.mtval)
self.trap_entry(m)

# Fetch is already stalled by ExceptionCauseRegister
Expand Down
10 changes: 5 additions & 5 deletions coreblocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from coreblocks.core_structs.rf import RegisterFile
from coreblocks.priv.csr.csr_instances import GenericCSRRegisters
from coreblocks.frontend.frontend import CoreFrontend
from coreblocks.priv.traps.exception import ExceptionCauseRegister
from coreblocks.priv.traps.exception import ExceptionInformationRegister
from coreblocks.scheduler.scheduler import Scheduler
from coreblocks.backend.annoucement import ResultAnnouncement
from coreblocks.backend.retirement import Retirement
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, *, gen_params: GenParams):

self.connections.add_dependency(CommonBusDataKey(), self.bus_master_data_adapter)

self.exception_cause_register = ExceptionCauseRegister(
self.exception_information_register = ExceptionInformationRegister(
self.gen_params,
rob_get_indices=self.ROB.get_indices,
fetch_stall_exception=self.frontend.stall,
Expand Down Expand Up @@ -134,7 +134,7 @@ def elaborate(self, platform):
gen_params=self.gen_params,
)

m.submodules.exception_cause_register = self.exception_cause_register
m.submodules.exception_information_register = self.exception_information_register

fetch_resume_fb, fetch_resume_unifiers = self.connections.get_dependency(FetchResumeKey())
m.submodules.fetch_resume_unifiers = ModuleConnector(**fetch_resume_unifiers)
Expand All @@ -151,8 +151,8 @@ def elaborate(self, platform):
r_rat_peek=rrat.peek,
free_rf_put=free_rf_fifo.write,
rf_free=rf.free,
exception_cause_get=self.exception_cause_register.get,
exception_cause_clear=self.exception_cause_register.clear,
exception_cause_get=self.exception_information_register.get,
exception_cause_clear=self.exception_information_register.clear,
frat_rename=frat.rename,
fetch_continue=self.frontend.resume_from_exception,
instr_decrement=core_counter.decrement,
Expand Down
14 changes: 11 additions & 3 deletions coreblocks/frontend/decoder/decode_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def elaborate(self, platform):
]

exception_override = Signal()
m.d.comb += exception_override.eq(instr_decoder.illegal | raw.access_fault)
m.d.comb += exception_override.eq(instr_decoder.illegal | raw.access_fault.any())
exception_funct = Signal(Funct3)
with m.If(raw.access_fault):
with m.If(raw.access_fault.any()):
m.d.comb += exception_funct.eq(Funct3._EINSTRACCESSFAULT)
with m.Elif(instr_decoder.illegal):
self.perf_illegal_instr.incr(m)
Expand Down Expand Up @@ -95,7 +95,15 @@ def elaborate(self, platform):
"rl_s1": Mux(instr_decoder.rs1_v & (~exception_override), instr_decoder.rs1, 0),
"rl_s2": Mux(instr_decoder.rs2_v & (~exception_override), instr_decoder.rs2, 0),
},
"imm": instr_decoder.imm,
"imm": Mux(
~exception_override,
instr_decoder.imm,
Mux(
raw.access_fault.any(),
raw.access_fault, # pass access fault details to FU
raw.instr, # illegal instruction - pass raw instruction bits for `mtval`
),
),
"csr": instr_decoder.csr,
"pc": raw.pc,
},
Expand Down
8 changes: 8 additions & 0 deletions coreblocks/frontend/decoder/instr_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,14 @@ def elaborate(self, platform):
self.rs1_v.eq(0),
]

# HACK: pass logical registers in unused high bits of CSR instruction for `mtval` reconstruction
with m.If((self.optype == OpType.CSR_REG) | (self.optype == OpType.CSR_IMM)):
m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log : 32].eq(self.rd)
m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log].eq(
self.rs1
)
assert 32 - self.gen_params.isa.reg_cnt_log * 2 >= 5

# Instruction simplification

# lui rd, imm -> addi rd, x0, (imm << 12)
Expand Down
4 changes: 3 additions & 1 deletion coreblocks/frontend/decoder/rvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def elaborate(self, platform):

res = self.instr_mux(quadrant, quadrants)

m.d.comb += self.instr_out.eq(Mux(res[1], res[0], IllegalInstr()))
# In case of illegal instruction, output `instr_in` to be able to save it into `mtval` CSR.
# Decoder would still recognize it as illegal because of quadrant != 0b11
m.d.comb += self.instr_out.eq(Mux(res[1], res[0], self.instr_in))

return m
15 changes: 14 additions & 1 deletion coreblocks/frontend/fetch/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,19 +323,28 @@ def flush():
m.d.av_comb += [
raw_instrs[i].instr.eq(instrs[i]),
raw_instrs[i].pc.eq(params.pc_from_fb(fetch_block_addr, i)),
raw_instrs[i].access_fault.eq(access_fault),
raw_instrs[i].rvc.eq(s1_data.rvc[i]),
raw_instrs[i].predicted_taken.eq(redirect & (predcheck_res.fb_instr_idx == i)),
raw_instrs[i].access_fault.eq(
Mux(s1_data.access_fault, FetchLayouts.AccessFaultFlag.ACCESS_FAULT, 0)
),
]

if Extension.C in self.gen_params.isa.extensions:
with m.If(s1_data.instr_block_cross):
m.d.av_comb += raw_instrs[0].pc.eq(params.pc_from_fb(fetch_block_addr, 0) - 2)
with m.If(s1_data.access_fault):
# Mark that access fault happened only at second (current) half.
# If fault happened on the first half `instr_block_cross` would be false
m.d.av_comb += raw_instrs[0].access_fault.eq(
FetchLayouts.AccessFaultFlag.ACCESS_FAULT_ON_SECOND_HALF
)

with condition(m) as branch:
with branch(flushing_counter == 0):
with m.If(access_fault | unsafe_stall):
# TODO: Raise different code for page fault when supported
# could be passed in 3rd bit of access_fault
flush()
m.d.sync += stalled_unsafe.eq(1)
with m.Elif(redirect):
Expand Down Expand Up @@ -523,6 +532,7 @@ def elaborate(self, platform):

@def_method(m, self.predecode)
def _(instr):
quadrant = instr[0:2]
opcode = instr[2:7]
funct3 = instr[12:15]
rd = instr[7:12]
Expand Down Expand Up @@ -557,6 +567,9 @@ def _(instr):
with m.Default():
m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

with m.If(quadrant != 0b11):
m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

m.d.av_comb += ret.unsafe.eq(
(opcode == Opcode.SYSTEM) | ((opcode == Opcode.MISC_MEM) & (funct3 == Funct3.FENCEI))
)
Expand Down
20 changes: 18 additions & 2 deletions coreblocks/func_blocks/csr/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transactron.utils.dependencies import DependencyContext

from coreblocks.arch import OpType, Funct3, ExceptionCause, PrivilegeLevel
from coreblocks.arch.isa_consts import Opcode
from coreblocks.params import GenParams
from coreblocks.params.fu_params import BlockComponentParams
from coreblocks.func_blocks.interface.func_protocols import FuncBlock
Expand Down Expand Up @@ -189,7 +190,7 @@ def _(rs_entry_id, rs_data):
m.d.sync += assign(instr, rs_data)

with m.If(rs_data.exec_fn.op_type == OpType.CSR_IMM): # Pass immediate as first operand
m.d.sync += instr.s1_val.eq(rs_data.imm)
m.d.sync += instr.s1_val.eq(rs_data.imm[0:5])

m.d.sync += instr.valid.eq(1)

Expand All @@ -209,7 +210,21 @@ def _():
interrupt = self.dependency_manager.get_dependency(AsyncInterruptInsertSignalKey())

with m.If(exception):
report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc)
mtval = Signal(self.gen_params.isa.xlen)
# re-encode the CSR instruction to speed-up missing CSR emulation (optional, otherwise mtval must be 0)
m.d.av_comb += mtval[0:2].eq(0b11)
m.d.av_comb += mtval[2:7].eq(Opcode.SYSTEM)
m.d.av_comb += mtval[7:12].eq(instr.imm[32 - self.gen_params.isa.reg_cnt_log : 32]) # rl_rd
m.d.av_comb += mtval[12:15].eq(instr.exec_fn.funct3)
m.d.av_comb += mtval[15:20].eq(
Mux(
instr.exec_fn.op_type == OpType.CSR_IMM,
instr.imm[0:5],
instr.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log],
)
) # rl_s1 or imm
m.d.av_comb += mtval[20:32].eq(instr.csr)
report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc, mtval=mtval)
with m.Elif(interrupt):
# SPEC: "These conditions for an interrupt trap to occur [..] must also be evaluated immediately
# following [..] an explicit write to a CSR on which these interrupt trap conditions expressly depend."
Expand All @@ -221,6 +236,7 @@ def _():
rob_id=instr.rob_id,
cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT,
pc=instr.pc + self.gen_params.isa.ilen_bytes,
mtval=0,
)

m.d.sync += exception.eq(0)
Expand Down
29 changes: 20 additions & 9 deletions coreblocks/func_blocks/fu/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from coreblocks.params import GenParams, FunctionalComponentParams
from coreblocks.arch import OpType, Funct3, ExceptionCause
from coreblocks.interface.layouts import FuncUnitLayouts
from coreblocks.interface.layouts import FetchLayouts, FuncUnitLayouts
from transactron.utils import OneHotSwitch
from coreblocks.interface.keys import ExceptionReportKey, CSRInstancesKey

Expand Down Expand Up @@ -69,28 +69,39 @@ def _(arg):
m.d.comb += decoder.exec_fn.eq(arg.exec_fn)

cause = Signal(ExceptionCause)
mtval = Signal(self.gen_params.isa.xlen)

priv_level = self.dm.get_dependency(CSRInstancesKey()).m_mode.priv_mode.read(m).data

with OneHotSwitch(m, decoder.decode_fn) as OneHotCase:
with OneHotCase(ExceptionUnitFn.Fn.EBREAK):
m.d.comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += mtval.eq(arg.pc)
with OneHotCase(ExceptionUnitFn.Fn.ECALL):
with m.Switch(priv_level):
with m.Case(PrivilegeLevel.MACHINE):
m.d.comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M)
m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M)
with m.Case(PrivilegeLevel.USER):
m.d.comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_U)
m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_U)
m.d.av_comb += mtval.eq(0) # by SPEC
with OneHotCase(ExceptionUnitFn.Fn.INSTR_ACCESS_FAULT):
m.d.comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT)
m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT)
# With C extension access fault can be only on the second half of instruction, and mepc != mtval.
# This information is passed in imm field
m.d.av_comb += mtval.eq(
arg.pc + ((arg.imm & FetchLayouts.AccessFaultFlag.ACCESS_FAULT_ON_SECOND_HALF).any() << 1)
)
with OneHotCase(ExceptionUnitFn.Fn.ILLEGAL_INSTRUCTION):
m.d.comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION)
m.d.av_comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION)
m.d.av_comb += mtval.eq(arg.imm) # passed instruction bytes
with OneHotCase(ExceptionUnitFn.Fn.BREAKPOINT):
m.d.comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += mtval.eq(arg.pc)
with OneHotCase(ExceptionUnitFn.Fn.INSTR_PAGE_FAULT):
m.d.comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT)
m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT)
m.d.av_comb += mtval.eq(arg.pc + (arg.imm[1] << 1))

self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc)
self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc, mtval=mtval)

fifo.write(m, result=0, exception=1, rob_id=arg.rob_id, rp_dst=arg.rp_dst)

Expand Down
13 changes: 10 additions & 3 deletions coreblocks/func_blocks/fu/jumpbranch.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,22 +214,29 @@ def _():
# generated for a conditional branch that is not taken."
m.d.comb += exception.eq(1)
exception_report(
m, rob_id=instr.rob_id, cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED, pc=instr.pc
m,
rob_id=instr.rob_id,
cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED,
pc=instr.pc,
mtval=instr.jmp_addr,
)

with m.Elif(async_interrupt_active & ~is_auipc):
# Jump instructions are entry points for async interrupts.
# This way we can store known pc via report to global exception register and avoid it in ROB.
# Exceptions have priority, because the instruction that reports async interrupt is commited
# and exception would be lost.
m.d.comb += exception.eq(1)
exception_report(
m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result
m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result, mtval=0
)
with m.Elif(misprediction):
# Async interrupts can have priority, because `jump_result` is handled in the same way.
# No extra misprediction penalty will be introducted at interrupt return to `jump_result` address.
m.d.comb += exception.eq(1)
exception_report(m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result)
exception_report(
m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result, mtval=0
)

with m.If(~is_auipc):
self.fifo_branch_resolved.write(m, from_pc=instr.pc, next_pc=jump_result, misprediction=misprediction)
Expand Down
8 changes: 4 additions & 4 deletions coreblocks/func_blocks/fu/lsu/dummyLsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _(arg):
with m.If(~is_fence):
requests.write(m, arg)
with m.Else():
results_noop.write(m, data=0, exception=0, cause=0)
results_noop.write(m, data=0, exception=0, cause=0, addr=0)
issued_noop.write(m, arg)

# Issues load/store requests when the instruction is known, is a LOAD/STORE, and just before commit.
Expand Down Expand Up @@ -106,14 +106,14 @@ def _(arg):

with m.If(res["exception"]):
issued_noop.write(m, arg)
results_noop.write(m, data=0, exception=res["exception"], cause=res["cause"])
results_noop.write(m, data=0, exception=res["exception"], cause=res["cause"], addr=addr)
with m.Else():
issued.write(m, arg)

# Handles flushed instructions as a no-op.
with Transaction().body(m, request=flush):
arg = requests.read(m)
results_noop.write(m, data=0, exception=0, cause=0)
results_noop.write(m, data=0, exception=0, cause=0, addr=0)
issued_noop.write(m, arg)

@def_method(m, self.accept)
Expand All @@ -129,7 +129,7 @@ def _():
m.d.comb += arg.eq(issued_noop.read(m))

with m.If(res["exception"]):
self.report(m, rob_id=arg["rob_id"], cause=res["cause"], pc=arg["pc"])
self.report(m, rob_id=arg["rob_id"], cause=res["cause"], pc=arg["pc"], mtval=res["addr"])

self.log.debug(m, 1, "accept rob_id={} result=0x{:08x} exception={}", arg.rob_id, res.data, res.exception)

Expand Down
2 changes: 1 addition & 1 deletion coreblocks/func_blocks/fu/lsu/lsu_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,6 @@ def _():
Mux(request_args.store, ExceptionCause.STORE_ACCESS_FAULT, ExceptionCause.LOAD_ACCESS_FAULT)
)

return {"data": data, "exception": exception, "cause": cause}
return {"data": data, "exception": exception, "cause": cause, "addr": request_args.addr}

return m
Loading

0 comments on commit 8f2072f

Please sign in to comment.