Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WFI resume on mstatus.MIE disabled #735

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion coreblocks/func_blocks/fu/priv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
InstructionPrecommitKey,
FetchResumeKey,
FlushICacheKey,
WaitForInterruptResumeKey,
)
from coreblocks.func_blocks.interface.func_protocols import FuncUnit

Expand Down Expand Up @@ -80,6 +81,7 @@ def elaborate(self, platform):

mret = self.dm.get_dependency(MretKey())
async_interrupt_active = self.dm.get_dependency(AsyncInterruptInsertSignalKey())
wfi_resume = self.dm.get_dependency(WaitForInterruptResumeKey())
exception_report = self.dm.get_dependency(ExceptionReportKey())
csr = self.dm.get_dependency(CSRInstancesKey())
priv_mode = csr.m_mode.priv_mode
Expand Down Expand Up @@ -120,7 +122,9 @@ def _(arg):
with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.FENCEI)):
flush_icache(m)
with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.WFI) & ~illegal_wfi):
m.d.sync += finished.eq(async_interrupt_active)
# async_interrupt_active implies wfi_resume. WFI should continue normal execution
# when interrupt is enabled in xie, but disabled via global mstatus.xIE
m.d.sync += finished.eq(wfi_resume)

m.d.sync += illegal_instruction.eq(illegal_wfi | illegal_mret)

Expand Down
5 changes: 5 additions & 0 deletions coreblocks/interface/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class AsyncInterruptInsertSignalKey(SimpleKey[Signal]):
pass


@dataclass(frozen=True)
class WaitForInterruptResumeKey(SimpleKey[Signal]):
pass


@dataclass(frozen=True)
class MretKey(SimpleKey[Method]):
pass
Expand Down
13 changes: 12 additions & 1 deletion coreblocks/priv/traps/interrupt_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from coreblocks.interface.layouts import InternalInterruptControllerLayouts
from coreblocks.priv.csr.csr_register import CSRRegister
from coreblocks.params.genparams import GenParams
from coreblocks.interface.keys import AsyncInterruptInsertSignalKey, CSRInstancesKey, MretKey
from coreblocks.interface.keys import (
AsyncInterruptInsertSignalKey,
CSRInstancesKey,
MretKey,
WaitForInterruptResumeKey,
)

from transactron.core import Method, TModule, def_method
from transactron.core.transaction import Transaction
Expand Down Expand Up @@ -78,6 +83,9 @@ def __init__(self, gen_params: GenParams):
self.interrupt_insert = Signal()
self.dm.add_dependency(AsyncInterruptInsertSignalKey(), self.interrupt_insert)

self.wfi_resume = Signal()
self.dm.add_dependency(WaitForInterruptResumeKey(), self.wfi_resume)

self.interrupt_cause = Method(o=gen_params.get(InternalInterruptControllerLayouts).interrupt_cause)

self.mret = Method()
Expand Down Expand Up @@ -108,6 +116,9 @@ def elaborate(self, platform):
interrupt_pending = (mie & mip).any()
m.d.comb += self.interrupt_insert.eq(interrupt_pending & interrupt_enable)

# WFI is independent of global mstatus.xIE and mideleg
m.d.comb += self.wfi_resume.eq(interrupt_pending)

edge_report_interrupt = Signal(self.gen_params.isa.xlen)
level_report_interrupt = Signal(self.gen_params.isa.xlen)
m.d.comb += edge_report_interrupt.eq((self.custom_report << ISA_RESERVED_INTERRUPTS) & self.edge_reported_mask)
Expand Down
2 changes: 1 addition & 1 deletion test/asm/user_mode.asm
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ case0:
li x1, 1<<17
csrs mie, x1 # enable fixed level interrupt

# MIE = 0, but interrupts are active in U-MODE (when enabled in mie)
# mstatus.MIE = 0, but interrupts are active in U-MODE (when enabled in mie)

la x1, user_code2
csrw mepc, x1
Expand Down
27 changes: 27 additions & 0 deletions test/asm/wfi_no_mie.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_start:
li x4, 0

la x1, trap_handler
csrw mtvec, x1

li x1, 1<<17
csrs mie, x1 # enable fixed level interrupt
# but keep mstatus.MIE disabled

li x2, 1
loop:
wfi
addi x2, x2, -1
bnez x2, loop

j pass

fail:
j fail

pass:
li x8, 8
j pass

trap_handler:
j fail
14 changes: 11 additions & 3 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,17 @@ def test_interrupted_prog(self):


@parameterized_class(
("source_file", "cycle_count", "expected_regvals"),
("source_file", "cycle_count", "expected_regvals", "always_mmode"),
[
("user_mode.asm", 1000, {4: 5}),
("user_mode.asm", 1000, {4: 5}, False),
("wfi_no_mie.asm", 250, {8: 8}, True), # only using level enable
],
)
class TestCoreInterruptOnPrivMode(TestCoreAsmSourceBase):
source_file: str
cycle_count: int
expected_regvals: dict[int, int]
always_mmode: bool

def setup_method(self):
self.configuration = full_core_config.replace(
Expand All @@ -319,12 +321,16 @@ def run_with_interrupt_process(self):
while (yield self.m.core.interrupt_controller.mie.value) == 0 and cycles < self.cycle_count:
cycles += 1
yield Tick()
yield from self.random_wait(5)

while cycles < self.cycle_count:
yield from self.random_wait(5)
yield self.m.interrupt_level.eq(1)
cycles += 1
yield Tick()

if self.always_mmode: # if test happens only in m_mode, just enable fixed interrupt
continue

# wait for the interrupt to get registered
while (
yield self.m.core.csr_generic.m_mode.priv_mode.value
Expand All @@ -342,6 +348,8 @@ def run_with_interrupt_process(self):
cycles += 1
yield Tick()

yield from self.random_wait(5)

for reg_id, val in self.expected_regvals.items():
assert (yield from self.get_arch_reg_val(reg_id)) == val

Expand Down