Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement full mtval #712

Merged
merged 20 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions coreblocks/backend/retirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def flush_instr(rob_entry):
# Register RISC-V architectural trap in CSRs
m_csr.mcause.write(m, cause_entry)
m_csr.mepc.write(m, cause_register.pc)
m_csr.mtval.write(m, cause_register.mtval)
self.trap_entry(m)

# Fetch is already stalled by ExceptionCauseRegister
Expand Down
10 changes: 5 additions & 5 deletions coreblocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from coreblocks.core_structs.rf import RegisterFile
from coreblocks.priv.csr.csr_instances import GenericCSRRegisters
from coreblocks.frontend.frontend import CoreFrontend
from coreblocks.priv.traps.exception import ExceptionCauseRegister
from coreblocks.priv.traps.exception import ExceptionInformationRegister
from coreblocks.scheduler.scheduler import Scheduler
from coreblocks.backend.annoucement import ResultAnnouncement
from coreblocks.backend.retirement import Retirement
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, *, gen_params: GenParams):

self.connections.add_dependency(CommonBusDataKey(), self.bus_master_data_adapter)

self.exception_cause_register = ExceptionCauseRegister(
self.exception_information_register = ExceptionInformationRegister(
self.gen_params,
rob_get_indices=self.ROB.get_indices,
fetch_stall_exception=self.frontend.stall,
Expand Down Expand Up @@ -134,7 +134,7 @@ def elaborate(self, platform):
gen_params=self.gen_params,
)

m.submodules.exception_cause_register = self.exception_cause_register
m.submodules.exception_information_register = self.exception_information_register

fetch_resume_fb, fetch_resume_unifiers = self.connections.get_dependency(FetchResumeKey())
m.submodules.fetch_resume_unifiers = ModuleConnector(**fetch_resume_unifiers)
Expand All @@ -151,8 +151,8 @@ def elaborate(self, platform):
r_rat_peek=rrat.peek,
free_rf_put=free_rf_fifo.write,
rf_free=rf.free,
exception_cause_get=self.exception_cause_register.get,
exception_cause_clear=self.exception_cause_register.clear,
exception_cause_get=self.exception_information_register.get,
exception_cause_clear=self.exception_information_register.clear,
frat_rename=frat.rename,
fetch_continue=self.frontend.resume_from_exception,
instr_decrement=core_counter.decrement,
Expand Down
6 changes: 4 additions & 2 deletions coreblocks/frontend/decoder/decode_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def elaborate(self, platform):
]

exception_override = Signal()
m.d.comb += exception_override.eq(instr_decoder.illegal | raw.access_fault)
m.d.comb += exception_override.eq(instr_decoder.illegal | raw.access_fault.any())
exception_funct = Signal(Funct3)
with m.If(raw.access_fault):
m.d.comb += exception_funct.eq(Funct3._EINSTRACCESSFAULT)
Expand Down Expand Up @@ -95,7 +95,9 @@ def elaborate(self, platform):
"rl_s1": Mux(instr_decoder.rs1_v & (~exception_override), instr_decoder.rs1, 0),
"rl_s2": Mux(instr_decoder.rs2_v & (~exception_override), instr_decoder.rs2, 0),
},
"imm": instr_decoder.imm,
"imm": Mux(
~exception_override, instr_decoder.imm, Mux(raw.access_fault, raw.access_fault, raw.instr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be raw.pc in case when raw.access_fault is true?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is correct, Exception FU uses information from access_fault to determine correct PC. Switched to enum and documented this mux

),
"csr": instr_decoder.csr,
"pc": raw.pc,
},
Expand Down
8 changes: 8 additions & 0 deletions coreblocks/frontend/decoder/instr_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,14 @@ def elaborate(self, platform):
self.rs1_v.eq(0),
]

# HACK: pass logical registers in unused high bits of CSR instruction for `mtval` reconstruction
with m.If((self.optype == OpType.CSR_REG) | (self.optype == OpType.CSR_IMM)):
m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log : 32].eq(self.rd)
m.d.comb += self.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log].eq(
self.rs1
)
assert 32 - self.gen_params.isa.reg_cnt_log * 2 >= 5

# Instruction simplification

# lui rd, imm -> addi rd, x0, (imm << 12)
Expand Down
4 changes: 3 additions & 1 deletion coreblocks/frontend/decoder/rvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def elaborate(self, platform):

res = self.instr_mux(quadrant, quadrants)

m.d.comb += self.instr_out.eq(Mux(res[1], res[0], IllegalInstr()))
# In case of illegal instruction, output `instr_in` to be able to save it into `mtval` CSR.
# Decoder would still recognize it as illegal because of quadrant != 0b11
m.d.comb += self.instr_out.eq(Mux(res[1], res[0], self.instr_in))

return m
9 changes: 9 additions & 0 deletions coreblocks/frontend/fetch/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,16 @@ def flush():
if Extension.C in self.gen_params.isa.extensions:
with m.If(s1_data.instr_block_cross):
m.d.av_comb += raw_instrs[0].pc.eq(params.pc_from_fb(fetch_block_addr, 0) - 2)
with m.If(s1_data.access_fault):
m.d.av_comb += raw_instrs[0].access_fault.eq(
0b10
) # Mark that access fault happened only at second half

with condition(m) as branch:
with branch(flushing_counter == 0):
with m.If(access_fault | unsafe_stall):
# TODO: Raise different code for page fault when supported
# could be passed in 3rd bit of access_fault
flush()
m.d.sync += stalled_unsafe.eq(1)
with m.Elif(redirect):
Expand Down Expand Up @@ -523,6 +528,7 @@ def elaborate(self, platform):

@def_method(m, self.predecode)
def _(instr):
quadrant = instr[0:2]
opcode = instr[2:7]
funct3 = instr[12:15]
rd = instr[7:12]
Expand Down Expand Up @@ -557,6 +563,9 @@ def _(instr):
with m.Default():
m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

with m.If(quadrant != 0b11):
m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

m.d.av_comb += ret.unsafe.eq(
(opcode == Opcode.SYSTEM) | ((opcode == Opcode.MISC_MEM) & (funct3 == Funct3.FENCEI))
)
Expand Down
20 changes: 18 additions & 2 deletions coreblocks/func_blocks/csr/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transactron.utils.dependencies import DependencyContext

from coreblocks.arch import OpType, Funct3, ExceptionCause, PrivilegeLevel
from coreblocks.arch.isa_consts import Opcode
from coreblocks.params import GenParams
from coreblocks.params.fu_params import BlockComponentParams
from coreblocks.func_blocks.interface.func_protocols import FuncBlock
Expand Down Expand Up @@ -188,7 +189,7 @@ def _(rs_entry_id, rs_data):
m.d.sync += assign(instr, rs_data)

with m.If(rs_data.exec_fn.op_type == OpType.CSR_IMM): # Pass immediate as first operand
m.d.sync += instr.s1_val.eq(rs_data.imm)
m.d.sync += instr.s1_val.eq(rs_data.imm[0:5])

m.d.sync += instr.valid.eq(1)

Expand All @@ -208,7 +209,21 @@ def _():
interrupt = self.dependency_manager.get_dependency(AsyncInterruptInsertSignalKey())

with m.If(exception):
report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc)
mtval = Signal(self.gen_params.isa.xlen)
# re-encode the CSR instruction to speed-up missing CSR emulation (optional, otherwise mtval must be 0)
m.d.av_comb += mtval[0:2].eq(0b11)
m.d.av_comb += mtval[2:7].eq(Opcode.SYSTEM)
m.d.av_comb += mtval[7:12].eq(instr.imm[32 - self.gen_params.isa.reg_cnt_log : 32]) # rl_rd
m.d.av_comb += mtval[12:15].eq(instr.exec_fn.funct3)
m.d.av_comb += mtval[15:20].eq(
Mux(
instr.exec_fn.op_type == OpType.CSR_IMM,
instr.imm[0:5],
instr.imm[32 - self.gen_params.isa.reg_cnt_log * 2 : 32 - self.gen_params.isa.reg_cnt_log],
)
) # rl_s1 or imm
m.d.av_comb += mtval[20:32].eq(instr.csr)
report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc, mtval=mtval)
with m.Elif(interrupt):
# SPEC: "These conditions for an interrupt trap to occur [..] must also be evaluated immediately
# following [..] an explicit write to a CSR on which these interrupt trap conditions expressly depend."
Expand All @@ -220,6 +235,7 @@ def _():
rob_id=instr.rob_id,
cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT,
pc=instr.pc + self.gen_params.isa.ilen_bytes,
mtval=0,
)

m.d.sync += exception.eq(0)
Expand Down
23 changes: 16 additions & 7 deletions coreblocks/func_blocks/fu/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,32 @@ def _(arg):
m.d.comb += decoder.exec_fn.eq(arg.exec_fn)

cause = Signal(ExceptionCause)
mtval = Signal(self.gen_params.isa.xlen)

with OneHotSwitch(m, decoder.decode_fn) as OneHotCase:
with OneHotCase(ExceptionUnitFn.Fn.EBREAK):
m.d.comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += mtval.eq(arg.pc)
with OneHotCase(ExceptionUnitFn.Fn.ECALL):
# TODO: Switch privilege level when implemented
m.d.comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M)
m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M)
m.d.av_comb += mtval.eq(0) # by SPEC
with OneHotCase(ExceptionUnitFn.Fn.INSTR_ACCESS_FAULT):
m.d.comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT)
m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT)
# With C extension access fault can be only on the second half of instruction, and mepc != mtval.
# This information is passed in imm field
m.d.av_comb += mtval.eq(arg.pc + (arg.imm[1] << 1))
with OneHotCase(ExceptionUnitFn.Fn.ILLEGAL_INSTRUCTION):
m.d.comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION)
m.d.av_comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION)
m.d.av_comb += mtval.eq(arg.imm) # passed instruction bytes
with OneHotCase(ExceptionUnitFn.Fn.BREAKPOINT):
m.d.comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
m.d.av_comb += mtval.eq(arg.pc)
with OneHotCase(ExceptionUnitFn.Fn.INSTR_PAGE_FAULT):
m.d.comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT)
m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT)
m.d.av_comb += mtval.eq(arg.pc + (arg.imm[1] << 1))

self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc)
self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc, mtval=mtval)

fifo.write(m, result=0, exception=1, rob_id=arg.rob_id, rp_dst=arg.rp_dst)

Expand Down
13 changes: 10 additions & 3 deletions coreblocks/func_blocks/fu/jumpbranch.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,22 +214,29 @@ def _():
# generated for a conditional branch that is not taken."
m.d.comb += exception.eq(1)
exception_report(
m, rob_id=instr.rob_id, cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED, pc=instr.pc
m,
rob_id=instr.rob_id,
cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED,
pc=instr.pc,
mtval=instr.jmp_addr,
)

with m.Elif(async_interrupt_active & ~is_auipc):
# Jump instructions are entry points for async interrupts.
# This way we can store known pc via report to global exception register and avoid it in ROB.
# Exceptions have priority, because the instruction that reports async interrupt is commited
# and exception would be lost.
m.d.comb += exception.eq(1)
exception_report(
m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result
m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result, mtval=0
)
with m.Elif(misprediction):
# Async interrupts can have priority, because `jump_result` is handled in the same way.
# No extra misprediction penalty will be introducted at interrupt return to `jump_result` address.
m.d.comb += exception.eq(1)
exception_report(m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result)
exception_report(
m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result, mtval=0
)

with m.If(~is_auipc):
self.fifo_branch_resolved.write(m, from_pc=instr.pc, next_pc=jump_result, misprediction=misprediction)
Expand Down
8 changes: 4 additions & 4 deletions coreblocks/func_blocks/fu/lsu/dummyLsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _(arg):
with m.If(~is_fence):
requests.write(m, arg)
with m.Else():
results_noop.write(m, data=0, exception=0, cause=0)
results_noop.write(m, data=0, exception=0, cause=0, addr=0)
issued_noop.write(m, arg)

# Issues load/store requests when the instruction is known, is a LOAD/STORE, and just before commit.
Expand Down Expand Up @@ -106,14 +106,14 @@ def _(arg):

with m.If(res["exception"]):
issued_noop.write(m, arg)
results_noop.write(m, data=0, exception=res["exception"], cause=res["cause"])
results_noop.write(m, data=0, exception=res["exception"], cause=res["cause"], addr=addr)
with m.Else():
issued.write(m, arg)

# Handles flushed instructions as a no-op.
with Transaction().body(m, request=flush):
arg = requests.read(m)
results_noop.write(m, data=0, exception=0, cause=0)
results_noop.write(m, data=0, exception=0, cause=0, addr=0)
issued_noop.write(m, arg)

@def_method(m, self.accept)
Expand All @@ -129,7 +129,7 @@ def _():
m.d.comb += arg.eq(issued_noop.read(m))

with m.If(res["exception"]):
self.report(m, rob_id=arg["rob_id"], cause=res["cause"], pc=arg["pc"])
self.report(m, rob_id=arg["rob_id"], cause=res["cause"], pc=arg["pc"], mtval=res["addr"])

self.log.debug(m, 1, "accept rob_id={} result=0x{:08x} exception={}", arg.rob_id, res.data, res.exception)

Expand Down
2 changes: 1 addition & 1 deletion coreblocks/func_blocks/fu/lsu/lsu_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,6 @@ def _():
Mux(request_args.store, ExceptionCause.STORE_ACCESS_FAULT, ExceptionCause.LOAD_ACCESS_FAULT)
)

return {"data": data, "exception": exception, "cause": cause}
return {"data": data, "exception": exception, "cause": cause, "addr": request_args.addr}

return m
4 changes: 3 additions & 1 deletion coreblocks/func_blocks/fu/priv.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def _():
# Interrupt is reported on this xRET instruction with return address set to instruction that we
# would normally return to (mepc value is preserved)
m.d.comb += exception.eq(1)
exception_report(m, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=ret_pc, rob_id=instr_rob)
exception_report(
m, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=ret_pc, rob_id=instr_rob, mtval=0
)
with m.Else():
log.info(m, True, "Unstalling fetch from the priv unit new_pc=0x{:x}", ret_pc)
# Unstall the fetch
Expand Down
16 changes: 10 additions & 6 deletions coreblocks/interface/layouts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional
from amaranth import signed
from amaranth.lib.data import StructLayout, ArrayLayout
from amaranth.lib.data import ArrayLayout
from coreblocks.params import GenParams
from coreblocks.arch import *
from transactron.utils import LayoutList, LayoutListField, layout_subset
from transactron.utils.transactron_helpers import from_method_layout, make_layout
from transactron.utils.transactron_helpers import from_method_layout, make_layout, extend_layout

__all__ = [
"CommonLayoutFields",
Expand Down Expand Up @@ -444,7 +444,7 @@ class FetchLayouts:
def __init__(self, gen_params: GenParams):
fields = gen_params.get(CommonLayoutFields)

self.access_fault: LayoutListField = ("access_fault", 1)
self.access_fault: LayoutListField = ("access_fault", 2)
piotro888 marked this conversation as resolved.
Show resolved Hide resolved
"""Instruction fetch failed."""

self.raw_instr = make_layout(
Expand Down Expand Up @@ -581,7 +581,7 @@ def __init__(self, gen_params: GenParams):

self.issue_out = make_layout(fields.exception, fields.cause)

self.accept = make_layout(fields.data, fields.exception, fields.cause)
self.accept = make_layout(fields.data, fields.exception, fields.cause, fields.addr)


class CSRRegisterLayouts:
Expand Down Expand Up @@ -634,20 +634,24 @@ def __init__(self, gen_params: GenParams):


class ExceptionRegisterLayouts:
"""Layouts used in the exception register."""
"""Layouts used in the exception information register."""

def __init__(self, gen_params: GenParams):
fields = gen_params.get(CommonLayoutFields)

self.mtval: LayoutListField = ("mtval", gen_params.isa.xlen)
""" Value to set for mtval CSR register """

self.valid: LayoutListField = ("valid", 1)

self.report = make_layout(
fields.cause,
fields.rob_id,
fields.pc,
self.mtval,
)

self.get = StructLayout(self.report.members | make_layout(self.valid).members)
self.get = extend_layout(self.report, self.valid)


class InternalInterruptControllerLayouts:
Expand Down
2 changes: 2 additions & 0 deletions coreblocks/priv/csr/csr_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def __init__(self, gen_params: GenParams):
mepc_ro_bits = 0b1 if Extension.C in gen_params.isa.extensions else 0b11 # pc alignment (SPEC)
self.mepc = CSRRegister(CSRAddress.MEPC, gen_params, ro_bits=mepc_ro_bits)

self.mtval = CSRRegister(CSRAddress.MTVAL, gen_params)

def elaborate(self, platform):
m = Module()

Expand Down
Loading