From cf30bbe30f3126f8c5efbf5f4c943f2f6aa43c61 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Tue, 2 Jan 2024 13:26:10 +0100 Subject: [PATCH] Exclusive branches (#551) --- coreblocks/fu/jumpbranch.py | 17 +-- coreblocks/structs_common/csr.py | 22 +-- test/transactions/test_branches.py | 98 ++++++++++++ test/transactions/test_methods.py | 23 +++ transactron/core.py | 230 +++++++++++++++++++++++++---- 5 files changed, 329 insertions(+), 61 deletions(-) create mode 100644 test/transactions/test_branches.py diff --git a/coreblocks/fu/jumpbranch.py b/coreblocks/fu/jumpbranch.py index c90e6f77e..247ba9441 100644 --- a/coreblocks/fu/jumpbranch.py +++ b/coreblocks/fu/jumpbranch.py @@ -3,12 +3,10 @@ from enum import IntFlag, auto from typing import Sequence -from coreblocks.params.layouts import ExceptionRegisterLayouts from transactron import * from transactron.core import def_method from transactron.lib import * -from transactron.utils import assign from coreblocks.params import * from coreblocks.params.keys import AsyncInterruptInsertSignalKey @@ -171,30 +169,19 @@ def _(arg): AsyncInterruptInsertSignalKey() ) - exception_entry = Record(self.gen_params.get(ExceptionRegisterLayouts).report) - with m.If(~is_auipc & jb.taken & jmp_addr_misaligned): # Spec: "[...] if the target address is not four-byte aligned. This exception is reported on the branch # or jump instruction, not on the target instruction. No instruction-address-misaligned exception is # generated for a conditional branch that is not taken." m.d.comb += exception.eq(1) - m.d.comb += assign( - exception_entry, - {"rob_id": arg.rob_id, "cause": ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED, "pc": arg.pc}, - ) + exception_report(m, rob_id=arg.rob_id, cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED, pc=arg.pc) with m.Elif(async_interrupt_active & ~is_auipc): # Jump instructions are entry points for async interrupts. # This way we can store known pc via report to global exception register and avoid it in ROB. # Exceptions have priority, because the instruction that reports async interrupt is commited # and exception would be lost. m.d.comb += exception.eq(1) - m.d.comb += assign( - exception_entry, - {"rob_id": arg.rob_id, "cause": ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, "pc": jump_result}, - ) - - with m.If(exception): - exception_report(m, exception_entry) + exception_report(m, rob_id=arg.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result) fifo_res.write(m, rob_id=arg.rob_id, result=jb.reg_res, rp_dst=arg.rp_dst, exception=exception) diff --git a/coreblocks/structs_common/csr.py b/coreblocks/structs_common/csr.py index e4bad70e8..bebb5697b 100644 --- a/coreblocks/structs_common/csr.py +++ b/coreblocks/structs_common/csr.py @@ -7,7 +7,7 @@ from coreblocks.params.genparams import GenParams from coreblocks.params.dependencies import DependencyManager, ListKey from coreblocks.params.fu_params import BlockComponentParams -from coreblocks.params.layouts import ExceptionRegisterLayouts, FetchLayouts, FuncUnitLayouts, CSRLayouts +from coreblocks.params.layouts import FetchLayouts, FuncUnitLayouts, CSRLayouts from coreblocks.params.isa import Funct3, ExceptionCause from coreblocks.params.keys import ( AsyncInterruptInsertSignalKey, @@ -333,29 +333,21 @@ def _(): report = self.dependency_manager.get_dependency(ExceptionReportKey()) interrupt = self.dependency_manager.get_dependency(AsyncInterruptInsertSignalKey()) - exception_entry = Record(self.gen_params.get(ExceptionRegisterLayouts).report) with m.If(exception): - m.d.comb += assign( - exception_entry, - {"rob_id": instr.rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": instr.pc}, - ) + report(m, rob_id=instr.rob_id, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=instr.pc) with m.Elif(interrupt): # SPEC: "These conditions for an interrupt trap to occur [..] must also be evaluated immediately # following [..] an explicit write to a CSR on which these interrupt trap conditions expressly depend." # At this time CSR operation is finished. If it caused triggering an interrupt, it would be represented # by interrupt signal in this cycle. # CSR instructions are never compressed, PC+4 is always next instruction - m.d.comb += assign( - exception_entry, - { - "rob_id": instr.rob_id, - "cause": ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, - "pc": instr.pc + self.gen_params.isa.ilen_bytes, - }, + report( + m, + rob_id=instr.rob_id, + cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, + pc=instr.pc + self.gen_params.isa.ilen_bytes, ) - with m.If(exception | interrupt): - report(m, exception_entry) m.d.sync += exception.eq(0) diff --git a/test/transactions/test_branches.py b/test/transactions/test_branches.py new file mode 100644 index 000000000..ba2a4545a --- /dev/null +++ b/test/transactions/test_branches.py @@ -0,0 +1,98 @@ +from amaranth import * +from itertools import product +from transactron.core import ( + CtrlPath, + MethodMap, + TModule, + Method, + Transaction, + TransactionManager, + TransactionModule, + def_method, +) +from unittest import TestCase +from ..common import TestCaseWithSimulator + + +class TestExclusivePath(TestCase): + def test_exclusive_path(self): + m = TModule() + m._MustUse__silence = True # type: ignore + + with m.If(0): + cp0 = m.ctrl_path + with m.Switch(3): + with m.Case(0): + cp0a0 = m.ctrl_path + with m.Case(1): + cp0a1 = m.ctrl_path + with m.Default(): + cp0a2 = m.ctrl_path + with m.If(1): + cp0b0 = m.ctrl_path + with m.Else(): + cp0b1 = m.ctrl_path + with m.Elif(1): + cp1 = m.ctrl_path + with m.FSM(): + with m.State("start"): + cp10 = m.ctrl_path + with m.State("next"): + cp11 = m.ctrl_path + with m.Else(): + cp2 = m.ctrl_path + + def mutually_exclusive(*cps: CtrlPath): + return all(cpa.exclusive_with(cpb) for i, cpa in enumerate(cps) for cpb in cps[i + 1 :]) + + def pairwise_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): + return all(cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) + + def pairwise_not_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): + return all(not cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) + + self.assertTrue(mutually_exclusive(cp0, cp1, cp2)) + self.assertTrue(mutually_exclusive(cp0a0, cp0a1, cp0a2)) + self.assertTrue(mutually_exclusive(cp0b0, cp0b1)) + self.assertTrue(mutually_exclusive(cp10, cp11)) + self.assertTrue(pairwise_exclusive([cp0, cp0a0, cp0a1, cp0a2, cp0b0, cp0b1], [cp1, cp10, cp11])) + self.assertTrue(pairwise_not_exclusive([cp0, cp0a0, cp0a1, cp0a2], [cp0, cp0b0, cp0b1])) + + +class ExclusiveConflictRemovalCircuit(Elaboratable): + def __init__(self): + self.sel = Signal() + + def elaborate(self, platform): + m = TModule() + + called_method = Method(i=[], o=[]) + + @def_method(m, called_method) + def _(): + pass + + with m.If(self.sel): + with Transaction().body(m): + called_method(m) + with m.Else(): + with Transaction().body(m): + called_method(m) + + return m + + +class TestExclusiveConflictRemoval(TestCaseWithSimulator): + def test_conflict_removal(self): + circ = ExclusiveConflictRemovalCircuit() + + tm = TransactionManager() + dut = TransactionModule(circ, tm) + + with self.run_simulation(dut): + pass + + cgr, _, _ = tm._conflict_graph(MethodMap(tm.transactions)) + + for s in cgr.values(): + self.assertFalse(s) diff --git a/test/transactions/test_methods.py b/test/transactions/test_methods.py index df05d259f..22c3dab45 100644 --- a/test/transactions/test_methods.py +++ b/test/transactions/test_methods.py @@ -162,6 +162,29 @@ def elaborate(self, platform): self.assert_re("called twice", Twice()) + def test_twice_cond(self): + class Twice(Elaboratable): + def __init__(self): + self.meth1 = Method() + self.meth2 = Method() + + def elaborate(self, platform): + m = TModule() + m._MustUse__silence = True # type: ignore + + with self.meth1.body(m): + pass + + with self.meth2.body(m): + with m.If(1): + self.meth1(m) + with m.Else(): + self.meth1(m) + + return m + + Fragment.get(TransactionModule(Twice()), platform=None) + def test_diamond(self): class Diamond(Elaboratable): def __init__(self): diff --git a/transactron/core.py b/transactron/core.py index ca83f9985..98daa8b76 100644 --- a/transactron/core.py +++ b/transactron/core.py @@ -17,6 +17,7 @@ from os import environ from graphlib import TopologicalSorter from typing_extensions import Self +from dataclasses import dataclass, replace from amaranth import * from amaranth import tracer from itertools import count, chain, filterfalse, product @@ -227,6 +228,18 @@ def _conflict_graph(method_map: MethodMap) -> Tuple[TransactionGraph, Transactio Linear ordering of transactions which is consistent with priority constraints. """ + def transactions_exclusive(trans1: Transaction, trans2: Transaction): + tms1 = [trans1] + method_map.methods_by_transaction[trans1] + tms2 = [trans2] + method_map.methods_by_transaction[trans2] + + # if first transaction is exclusive with the second transaction, or this is true for + # any called methods, the transactions will never run at the same time + for tm1, tm2 in product(tms1, tms2): + if tm1.ctrl_path.exclusive_with(tm2.ctrl_path): + return True + + return False + cgr: TransactionGraph = {} # Conflict graph pgr: TransactionGraph = {} # Priority graph rgr: TransactionGraph = {} # Relation graph @@ -253,7 +266,7 @@ def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict: continue for transaction1 in method_map.transactions_for(method): for transaction2 in method_map.transactions_for(method): - if transaction1 is not transaction2: + if transaction1 is not transaction2 and not transactions_exclusive(transaction1, transaction2): add_edge(transaction1, transaction2, Priority.UNDEFINED, True) relations = [ @@ -271,7 +284,8 @@ def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict: for trans_start in method_map.transactions_for(start): for trans_end in method_map.transactions_for(end): - add_edge(trans_start, trans_end, relation["priority"], relation["conflict"]) + conflict = relation["conflict"] and not transactions_exclusive(trans_start, trans_end) + add_edge(trans_start, trans_end, relation["priority"], conflict) porder: PriorityOrder = {} @@ -389,9 +403,11 @@ def maximal(group: frozenset[Transaction]): method.ready = transaction.request method.run = transaction.grant method.defined = transaction.defined + method.method_calls = transaction.method_calls method.method_uses = transaction.method_uses method.relations = transaction.relations method.def_order = transaction.def_order + method.ctrl_path = transaction.ctrl_path methods[transaction] = method for elem in method_map.methods_and_transactions: @@ -424,6 +440,9 @@ def elaborate(self, platform): m = Module() m.submodules.merge_manager = merge_manager + for elem in method_map.methods_and_transactions: + elem._set_method_uses(m) + for transaction in self.transactions: ready = [ method_map.readiness_by_method_and_transaction[transaction, method] @@ -610,6 +629,118 @@ def __setitem__(self, name: str, value): return self.__setattr__(name, value) +class EnterType(Enum): + """Characterizes stack behavior of Amaranth's context managers for control structures.""" + + #: Used for `m.If`, `m.Switch` and `m.FSM`. + PUSH = auto() + #: Used for `m.Elif` and `m.Else`. + ADD = auto() + #: Used for `m.Case`, `m.Default` and `m.State`. + ENTRY = auto() + + +@dataclass(frozen=True) +class PathEdge: + """Describes an edge in Amaranth's control tree. + + Attributes + ---------- + alt : int + Which alternative (e.g. case of `m.If` or m.Switch`) is described. + par : int + Which parallel control structure (e.g. `m.If` at the same level) is described. + """ + + alt: int = 0 + par: int = 0 + + +@dataclass +class CtrlPath: + """Describes a path in Amaranth's control tree. + + Attributes + ---------- + module : int + Unique number of the module the path refers to. + path : list[PathEdge] + Path in the control tree, starting from the root. + """ + + module: int + path: list[PathEdge] + + def exclusive_with(self, other: "CtrlPath"): + """Decides if this path is mutually exclusive with some other path. + + Paths are mutually exclusive if they refer to the same module and + diverge on different alternatives of the same control structure. + + Arguments + --------- + other : CtrlPath + The other path this path is compared to. + """ + common_prefix = [] + for a, b in zip(self.path, other.path): + if a == b: + common_prefix.append(a) + elif a.par != b.par: + return False + else: + break + + return ( + self.module == other.module + and len(common_prefix) != len(self.path) + and len(common_prefix) != len(other.path) + ) + + +class CtrlPathBuilder: + """Constructs control paths. + + Used internally by `TModule`.""" + + def __init__(self, module: int): + """ + Parameters + ---------- + module: int + Unique module identifier. + """ + self.module = module + self.ctrl_path: list[PathEdge] = [] + self.previous: Optional[PathEdge] = None + + @contextmanager + def enter(self, enter_type=EnterType.PUSH): + et = EnterType + + match enter_type: + case et.ADD: + assert self.previous is not None + self.ctrl_path.append(replace(self.previous, alt=self.previous.alt + 1)) + case et.ENTRY: + self.ctrl_path[-1] = replace(self.ctrl_path[-1], alt=self.ctrl_path[-1].alt + 1) + case et.PUSH: + if self.previous is not None: + self.ctrl_path.append(PathEdge(par=self.previous.par + 1)) + else: + self.ctrl_path.append(PathEdge()) + self.previous = None + try: + yield + finally: + if enter_type in [et.PUSH, et.ADD]: + self.previous = self.ctrl_path.pop() + + def build_ctrl_path(self): + """Returns the current control path.""" + return CtrlPath(self.module, self.ctrl_path[:]) + + class TModule(ModuleLike, Elaboratable): """Extended Amaranth module for use with transactions. @@ -629,6 +760,8 @@ class TModule(ModuleLike, Elaboratable): statements together. """ + __next_uid = 0 + def __init__(self): self.main_module = Module() self.avoiding_module = Module() @@ -637,54 +770,65 @@ def __init__(self): self.submodules = self.main_module.submodules self.domains = self.main_module.domains self.fsm: Optional[FSM] = None + self.uid = TModule.__next_uid + self.path_builder = CtrlPathBuilder(self.uid) + TModule.__next_uid += 1 @contextmanager def AvoidedIf(self, cond: ValueLike): # noqa: N802 with self.main_module.If(cond): - yield + with self.path_builder.enter(EnterType.PUSH): + yield @contextmanager def If(self, cond: ValueLike): # noqa: N802 with self.main_module.If(cond): with self.avoiding_module.If(cond): - yield + with self.path_builder.enter(EnterType.PUSH): + yield @contextmanager def Elif(self, cond): # noqa: N802 with self.main_module.Elif(cond): with self.avoiding_module.Elif(cond): - yield + with self.path_builder.enter(EnterType.ADD): + yield @contextmanager def Else(self): # noqa: N802 with self.main_module.Else(): with self.avoiding_module.Else(): - yield + with self.path_builder.enter(EnterType.ADD): + yield @contextmanager def Switch(self, test: ValueLike): # noqa: N802 with self.main_module.Switch(test): with self.avoiding_module.Switch(test): - yield + with self.path_builder.enter(EnterType.PUSH): + yield @contextmanager def Case(self, *patterns: SwitchKey): # noqa: N802 with self.main_module.Case(*patterns): with self.avoiding_module.Case(*patterns): - yield + with self.path_builder.enter(EnterType.ENTRY): + yield @contextmanager def Default(self): # noqa: N802 with self.main_module.Default(): with self.avoiding_module.Default(): - yield + with self.path_builder.enter(EnterType.ENTRY): + yield @contextmanager def FSM(self, reset: Optional[str] = None, domain: str = "sync", name: str = "fsm"): # noqa: N802 old_fsm = self.fsm with self.main_module.FSM(reset, domain, name) as fsm: self.fsm = fsm - yield fsm + with self.path_builder.enter(EnterType.PUSH): + yield fsm self.fsm = old_fsm @contextmanager @@ -692,7 +836,8 @@ def State(self, name: str): # noqa: N802 assert self.fsm is not None with self.main_module.State(name): with self.avoiding_module.If(self.fsm.ongoing(name)): - yield + with self.path_builder.enter(EnterType.ENTRY): + yield @property def next(self) -> NoReturn: @@ -702,6 +847,10 @@ def next(self) -> NoReturn: def next(self, name: str): self.main_module.next = name + @property + def ctrl_path(self): + return self.path_builder.build_ctrl_path() + @property def _MustUse__silence(self): # noqa: N802 return self.main_module._MustUse__silence @@ -726,17 +875,20 @@ class TransactionBase(Owned, Protocol): defined: bool = False name: str src_loc: SrcLoc - method_uses: dict["Method", Tuple[Record, ValueLike]] + method_uses: dict["Method", tuple[Record, Signal]] + method_calls: defaultdict["Method", list[tuple[CtrlPath, Record, ValueLike]]] relations: list[RelationBase] simultaneous_list: list[TransactionOrMethod] independent_list: list[TransactionOrMethod] + ctrl_path: CtrlPath = CtrlPath(-1, []) def __init__(self, *, src_loc: int | SrcLoc): self.src_loc = get_src_loc(src_loc) - self.method_uses: dict["Method", Tuple[Record, ValueLike]] = dict() - self.relations: list[RelationBase] = [] - self.simultaneous_list: list[TransactionOrMethod] = [] - self.independent_list: list[TransactionOrMethod] = [] + self.method_uses = {} + self.method_calls = defaultdict(list) + self.relations = [] + self.simultaneous_list = [] + self.independent_list = [] def add_conflict(self, end: TransactionOrMethod, priority: Priority = Priority.UNDEFINED) -> None: """Registers a conflict. @@ -773,11 +925,6 @@ def schedule_before(self, end: TransactionOrMethod) -> None: RelationBase(end=end, priority=Priority.LEFT, conflict=False, silence_warning=self.owner != end.owner) ) - def use_method(self, method: "Method", arg: Record, enable: ValueLike): - if method in self.method_uses: - raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction '{self.name}'") - self.method_uses[method] = (arg, enable) - def simultaneous(self, *others: TransactionOrMethod) -> None: """Adds simultaneity relations. @@ -829,6 +976,8 @@ def _independent(self, *others: TransactionOrMethod) -> None: @contextmanager def context(self: TransactionOrMethodBound, m: TModule) -> Iterator[TransactionOrMethodBound]: + self.ctrl_path = m.ctrl_path + parent = TransactionBase.peek() if parent is not None: parent.schedule_before(self) @@ -839,6 +988,20 @@ def context(self: TransactionOrMethodBound, m: TModule) -> Iterator[TransactionO yield self finally: TransactionBase.stack.pop() + self.defined = True + + def _set_method_uses(self, m: ModuleLike): + for method, calls in self.method_calls.items(): + arg_rec, enable_sig = self.method_uses[method] + if len(calls) == 1: + m.d.comb += arg_rec.eq(calls[0][1]) + m.d.comb += enable_sig.eq(calls[0][2]) + else: + call_ens = Cat([en for _, _, en in calls]) + + for i in OneHotSwitchDynamic(m, call_ens): + m.d.comb += arg_rec.eq(calls[i][1]) + m.d.comb += enable_sig.eq(1) @classmethod def get(cls) -> Self: @@ -957,7 +1120,6 @@ def body(self, m: TModule, *, request: ValueLike = C(1)) -> Iterator["Transactio with self.context(m): with m.AvoidedIf(self.grant): yield self - self.defined = True def __repr__(self) -> str: return "(transaction {})".format(self.name) @@ -1154,14 +1316,11 @@ def body( self.def_order = next(TransactionBase.def_counter) self.validate_arguments = validate_arguments - try: - m.d.av_comb += self.ready.eq(ready) - m.d.top_comb += self.data_out.eq(out) - with self.context(m): - with m.AvoidedIf(self.run): - yield self.data_in - finally: - self.defined = True + m.d.av_comb += self.ready.eq(ready) + m.d.top_comb += self.data_out.eq(out) + with self.context(m): + with m.AvoidedIf(self.run): + yield self.data_in def _validate_arguments(self, arg_rec: Record) -> ValueLike: if self.validate_arguments is not None: @@ -1228,7 +1387,16 @@ def __call__( enable_sig = Signal(name=self.owned_name + "_enable") m.d.av_comb += enable_sig.eq(enable) m.d.top_comb += assign(arg_rec, arg, fields=AssignType.ALL) - TransactionBase.get().use_method(self, arg_rec, enable_sig) + + caller = TransactionBase.get() + if not all(ctrl_path.exclusive_with(m.ctrl_path) for ctrl_path, _, _ in caller.method_calls[self]): + raise RuntimeError(f"Method '{self.name}' can't be called twice from the same caller '{caller.name}'") + caller.method_calls[self].append((m.ctrl_path, arg_rec, enable_sig)) + + if self not in caller.method_uses: + arg_rec_use = Record.like(self.data_in) + arg_rec_enable_sig = Signal() + caller.method_uses[self] = (arg_rec_use, arg_rec_enable_sig) return self.data_out