diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index b9af6482b..fc4f73052 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -36,9 +36,6 @@ jobs: # https://github.com/actions/runner/issues/2033 chown -R $(id -u):$(id -g) $PWD - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -114,9 +111,6 @@ jobs: # https://github.com/actions/runner/issues/2033 chown -R $(id -u):$(id -g) $PWD - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -165,9 +159,6 @@ jobs: # https://github.com/actions/runner/issues/2033 chown -R $(id -u):$(id -g) $PWD - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/.github/workflows/deploy_gh_pages.yml b/.github/workflows/deploy_gh_pages.yml index 07dad5c22..eaf35d90a 100644 --- a/.github/workflows/deploy_gh_pages.yml +++ b/.github/workflows/deploy_gh_pages.yml @@ -23,9 +23,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6609cac4d..3d3650c57 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,9 +20,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -153,9 +150,6 @@ jobs: git config --global --add safe.directory /__w/coreblocks/coreblocks git submodule > .gitmodules-hash - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -263,9 +257,6 @@ jobs: git config --global --add safe.directory /__w/coreblocks/coreblocks git submodule > .gitmodules-hash - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -318,9 +309,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: @@ -353,9 +341,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Checkout submodules - run: git submodule update --init --recursive amaranth-stubs - - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/.gitmodules b/.gitmodules index e1e6ec15d..8dea05eb8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,6 +8,3 @@ [submodule "test/external/riscof/riscv-arch-test"] path = test/external/riscof/riscv-arch-test url = https://github.com/riscv-non-isa/riscv-arch-test.git -[submodule "amaranth-stubs"] - path = amaranth-stubs - url = https://github.com/kuznia-rdzeni/amaranth-stubs.git diff --git a/README.md b/README.md index dcd7fb056..5a3b4d0ac 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Coreblocks is an experimental, modular out-of-order [RISC-V](https://riscv.org/s * Simplicity. Coreblocks is an academic project, accessible to students. It should be suitable for teaching essentials of out-of-order architectures. * Modularity. 
We want to be able to easily experiment with the core by adding, replacing and modifying modules without changing the source too much. - For this goal, we designed a [transaction system](https://kuznia-rdzeni.github.io/coreblocks/Transactions.html) inspired by [Bluespec](http://wiki.bluespec.com/). + For this goal, we designed a transaction system called [Transactron](https://github.com/kuznia-rdzeni/transactron), which is inspired by [Bluespec](http://wiki.bluespec.com/). * Fine-grained testing. Outside of the integration tests for the full core, modules are tested individually. This is to support an agile style of development. @@ -25,9 +25,6 @@ The core currently supports the full RV32I instruction set and several extension Exceptions and some of machine-mode CSRs are supported, the support for interrupts is currently rudimentary and incompatible with the RISC-V spec. Coreblocks can be used with [LiteX](https://github.com/enjoy-digital/litex) (currently using a [patched version](https://github.com/kuznia-rdzeni/litex/tree/coreblocks)). -The transaction system we use as the foundation for the core is well-tested and usable. -We plan to make it available as a separate Python package. - ## Documentation The [documentation for our project](https://kuznia-rdzeni.github.io/coreblocks/) is automatically generated using [Sphinx](https://www.sphinx-doc.org/). diff --git a/amaranth-stubs b/amaranth-stubs deleted file mode 160000 index c0325b42e..000000000 --- a/amaranth-stubs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c0325b42e4553def483a82ffed14fdc6bf353bdb diff --git a/coreblocks/arch/isa_consts.py b/coreblocks/arch/isa_consts.py index bca452493..a472239cd 100644 --- a/coreblocks/arch/isa_consts.py +++ b/coreblocks/arch/isa_consts.py @@ -166,6 +166,12 @@ class PrivilegeLevel(IntEnum, shape=2): MACHINE = 0b11 +@unique +class TrapVectorMode(IntEnum, shape=2): + DIRECT = 0b00 + VECTORED = 0b01 + + @unique class InterruptCauseNumber(IntEnum): SSI = 1 # supervisor software interrupt diff --git a/coreblocks/backend/retirement.py b/coreblocks/backend/retirement.py index c5bd39d64..20f704276 100644 --- a/coreblocks/backend/retirement.py +++ b/coreblocks/backend/retirement.py @@ -10,6 +10,7 @@ from coreblocks.arch import ExceptionCause from coreblocks.interface.keys import CoreStateKey, CSRInstancesKey, InstructionPrecommitKey from coreblocks.priv.csr.csr_instances import CSRAddress, DoubleCounterCSR +from coreblocks.arch.isa_consts import TrapVectorMode class Retirement(Elaboratable): @@ -55,11 +56,16 @@ def __init__( max_latency=2 * 2**gen_params.rob_entries_bits, ) + layouts = self.gen_params.get(RetirementLayouts) self.dependency_manager = DependencyContext.get() self.core_state = Method(o=self.gen_params.get(RetirementLayouts).core_state, nonexclusive=True) self.dependency_manager.add_dependency(CoreStateKey(), self.core_state) - self.precommit = Method(o=self.gen_params.get(RetirementLayouts).precommit, nonexclusive=True) + # The argument is only used in argument validation, it is not needed in the method body. + # A dummy combiner is provided. + self.precommit = Method( + i=layouts.precommit_in, o=layouts.precommit_out, nonexclusive=True, combiner=lambda m, args, runs: 0 + ) self.dependency_manager.add_dependency(InstructionPrecommitKey(), self.precommit) def elaborate(self, platform): @@ -208,8 +214,17 @@ def flush_instr(rob_entry): self.perf_trap_latency.stop(m) handler_pc = Signal(self.gen_params.isa.xlen) - # mtvec without mode is [mxlen-1:2], mode is two last bits. 
Only direct mode is supported - m.d.av_comb += handler_pc.eq(m_csr.mtvec.read(m).data & ~(0b11)) + mtvec_offset = Signal(self.gen_params.isa.xlen) + mtvec_base = m_csr.mtvec_base.read(m).data + mtvec_mode = m_csr.mtvec_mode.read(m).data + mcause = m_csr.mcause.read(m).data + + # When mode is Vectored, interrupts set pc to base + 4 * cause_number + with m.If(mcause[-1] & (mtvec_mode == TrapVectorMode.VECTORED)): + m.d.av_comb += mtvec_offset.eq(mcause << 2) + + # (mtvec_base stores base[MXLEN-1:2]) + m.d.av_comb += handler_pc.eq((mtvec_base << 2) + mtvec_offset) resume_pc = Mux(continue_pc_override, continue_pc, handler_pc) m.d.sync += continue_pc_override.eq(0) @@ -228,9 +243,11 @@ def flush_instr(rob_entry): def _(): return {"flushing": core_flushing} - @def_method(m, self.precommit) - def _(): - rob_entry = self.rob_peek(m) - return {"rob_id": rob_entry.rob_id, "side_fx": side_fx} + rob_id_val = Signal(self.gen_params.rob_entries_bits) + + @def_method(m, self.precommit, validate_arguments=lambda rob_id: rob_id == rob_id_val) + def _(rob_id): + m.d.top_comb += rob_id_val.eq(self.rob_peek(m).rob_id) + return {"side_fx": side_fx} return m diff --git a/coreblocks/core.py b/coreblocks/core.py index 1a107622c..207821d24 100644 --- a/coreblocks/core.py +++ b/coreblocks/core.py @@ -1,11 +1,11 @@ from amaranth import * -from amaranth.lib.wiring import Component, flipped, connect, Out +from amaranth.lib.wiring import Component, flipped, connect, In, Out from transactron.utils.amaranth_ext.elaboratables import ModuleConnector from transactron.utils.dependencies import DependencyContext from coreblocks.priv.traps.instr_counter import CoreInstructionCounter from coreblocks.func_blocks.interface.func_blocks_unifier import FuncBlocksUnifier -from coreblocks.priv.traps.interrupt_controller import InternalInterruptController +from coreblocks.priv.traps.interrupt_controller import ISA_RESERVED_INTERRUPTS, InternalInterruptController from transactron.core import Transaction, TModule from transactron.lib import ConnectTrans, MethodProduct from coreblocks.interface.layouts import * @@ -35,12 +35,14 @@ class Core(Component): wb_instr: WishboneInterface wb_data: WishboneInterface + interrupts: Signal def __init__(self, *, gen_params: GenParams): super().__init__( { "wb_instr": Out(WishboneSignature(gen_params.wb_params)), "wb_data": Out(WishboneSignature(gen_params.wb_params)), + "interrupts": In(ISA_RESERVED_INTERRUPTS + gen_params.interrupt_custom_count), } ) @@ -96,8 +98,8 @@ def __init__(self, *, gen_params: GenParams): def elaborate(self, platform): m = TModule() - connect(m, flipped(self.wb_instr), self.wb_master_instr.wb_master) - connect(m, flipped(self.wb_data), self.wb_master_data.wb_master) + connect(m.top_module, flipped(self.wb_instr), self.wb_master_instr.wb_master) + connect(m.top_module, flipped(self.wb_data), self.wb_master_data.wb_master) m.submodules.wb_master_instr = self.wb_master_instr m.submodules.wb_master_data = self.wb_master_data @@ -115,6 +117,8 @@ def elaborate(self, platform): m.submodules.csr_generic = self.csr_generic m.submodules.interrupt_controller = self.interrupt_controller + m.d.comb += self.interrupt_controller.internal_report_level.eq(self.interrupts[0:16]) + m.d.comb += self.interrupt_controller.custom_report.eq(self.interrupts[16:]) m.submodules.core_counter = core_counter = CoreInstructionCounter(self.gen_params) @@ -136,10 +140,12 @@ def elaborate(self, platform): m.submodules.exception_information_register = self.exception_information_register - fetch_resume_fb, 
fetch_resume_unifiers = self.connections.get_dependency(FetchResumeKey()) - m.submodules.fetch_resume_unifiers = ModuleConnector(**fetch_resume_unifiers) + fetch_resume = self.connections.get_optional_dependency(FetchResumeKey()) + if fetch_resume is not None: + fetch_resume_fb, fetch_resume_unifiers = fetch_resume + m.submodules.fetch_resume_unifiers = ModuleConnector(**fetch_resume_unifiers) - m.submodules.fetch_resume_connector = ConnectTrans(fetch_resume_fb, self.frontend.resume_from_unsafe) + m.submodules.fetch_resume_connector = ConnectTrans(fetch_resume_fb, self.frontend.resume_from_unsafe) m.submodules.announcement = self.announcement m.submodules.func_blocks_unifier = self.func_blocks_unifier diff --git a/coreblocks/core_structs/rf.py b/coreblocks/core_structs/rf.py index cb2f32ccd..6865219b0 100644 --- a/coreblocks/core_structs/rf.py +++ b/coreblocks/core_structs/rf.py @@ -16,7 +16,7 @@ def __init__(self, *, gen_params: GenParams): self.internal_layout = make_layout(("reg_val", gen_params.isa.xlen), ("valid", 1)) self.read_layout = layouts.rf_read_out self.entries = Array( - Signal(self.internal_layout, reset={"reg_val": 0, "valid": k == 0}) + Signal(self.internal_layout, init={"reg_val": 0, "valid": k == 0}) for k in range(2**gen_params.phys_regs_bits) ) diff --git a/coreblocks/core_structs/rob.py b/coreblocks/core_structs/rob.py index 72a3b291d..c0cc4ac13 100644 --- a/coreblocks/core_structs/rob.py +++ b/coreblocks/core_structs/rob.py @@ -1,5 +1,4 @@ from amaranth import * -from amaranth.lib.data import View import amaranth.lib.memory as memory from transactron import Method, Transaction, def_method, TModule from transactron.lib.metrics import * @@ -19,7 +18,7 @@ def __init__(self, gen_params: GenParams) -> None: self.retire = Method() self.done = Array(Signal() for _ in range(2**self.params.rob_entries_bits)) self.exception = Array(Signal() for _ in range(2**self.params.rob_entries_bits)) - self.data = memory.Memory(shape=layouts.data_layout.size, depth=2**self.params.rob_entries_bits, init=[]) + self.data = memory.Memory(shape=layouts.data_layout, depth=2**self.params.rob_entries_bits, init=[]) self.get_indices = Method(o=layouts.get_indices, nonexclusive=True) self.perf_rob_wait_time = FIFOLatencyMeasurer( @@ -54,8 +53,8 @@ def elaborate(self, platform): @def_method(m, self.peek, ready=peek_possible) def _(): - return { # remove View after Amaranth upgrade - "rob_data": View(self.params.get(ROBLayouts).data_layout, read_port.data), + return { + "rob_data": read_port.data, "rob_id": start_idx, "exception": self.exception[start_idx], } diff --git a/coreblocks/frontend/fetch/fetch.py b/coreblocks/frontend/fetch/fetch.py index efe1b39d0..d26906e52 100644 --- a/coreblocks/frontend/fetch/fetch.py +++ b/coreblocks/frontend/fetch/fetch.py @@ -1,6 +1,5 @@ from amaranth import * from amaranth.lib.data import ArrayLayout -from amaranth.lib.coding import PriorityEncoder from coreblocks.interface.keys import FetchResumeKey from transactron.lib import BasicFifo, Semaphore, ConnectTrans, logging, Pipe from transactron.lib.metrics import * @@ -8,6 +7,7 @@ from transactron.utils import MethodLayout, popcount, assign from transactron.utils.dependencies import DependencyContext from transactron.utils.transactron_helpers import from_method_layout, make_layout +from transactron.utils.amaranth_ext.coding import PriorityEncoder from transactron import * from coreblocks.cache.iface import CacheInterface @@ -402,7 +402,13 @@ def _(): if self.gen_params.extra_verification: expect_unstall_unsafe = 
Signal() prev_stalled_unsafe = Signal() - unifier_ready = DependencyContext.get().get_dependency(FetchResumeKey())[0].ready + dependencies = DependencyContext.get() + fetch_resume = dependencies.get_optional_dependency(FetchResumeKey()) + if fetch_resume is not None: + unifier_ready = fetch_resume[0].ready + else: + unifier_ready = C(0) + m.d.sync += prev_stalled_unsafe.eq(stalled_unsafe) with m.FSM("running"): with m.State("running"): diff --git a/coreblocks/frontend/frontend.py b/coreblocks/frontend/frontend.py index f36a4fc17..f4212d342 100644 --- a/coreblocks/frontend/frontend.py +++ b/coreblocks/frontend/frontend.py @@ -62,7 +62,7 @@ def __init__(self, *, gen_params: GenParams, instr_bus: BusMasterInterface): def elaborate(self, platform): m = TModule() - if self.icache_refiller: + if self.gen_params.icache_params.enable: m.submodules.icache_refiller = self.icache_refiller m.submodules.icache = self.icache diff --git a/coreblocks/func_blocks/csr/csr.py b/coreblocks/func_blocks/csr/csr.py index b59ae18f8..a633a647f 100644 --- a/coreblocks/func_blocks/csr/csr.py +++ b/coreblocks/func_blocks/csr/csr.py @@ -101,13 +101,12 @@ def elaborate(self, platform): done = Signal() call_resume = Signal() exception = Signal() - precommitting = Signal() current_result = Signal(self.gen_params.isa.xlen) instr = Signal(StructLayout(self.csr_layouts.rs.data_layout.members | {"valid": 1})) - m.d.comb += ready_to_process.eq(precommitting & instr.valid & (instr.rp_s1 == 0)) + m.d.comb += ready_to_process.eq(instr.valid & (instr.rp_s1 == 0)) # RISCV Zicsr spec Table 1.1 should_read_csr = Signal() @@ -134,6 +133,9 @@ def elaborate(self, platform): # Methods used within this Tranaction are CSRRegister internal _fu_(read|write) handlers which are always ready with Transaction().body(m, request=(ready_to_process & ~done)): + precommit = self.dependency_manager.get_dependency(InstructionPrecommitKey()) + info = precommit(m, instr.rob_id) + m.d.top_comb += exe_side_fx.eq(info.side_fx) with m.Switch(instr.csr): for csr_number, methods in self.regfile.items(): read, write = methods @@ -257,14 +259,6 @@ def _(): # CSR instructions are never compressed, PC+4 is always next instruction return {"pc": instr.pc + self.gen_params.isa.ilen_bytes} - # Generate precommitting signal from precommit - with Transaction().body(m): - precommit = self.dependency_manager.get_dependency(InstructionPrecommitKey()) - info = precommit(m) - with m.If(instr.rob_id == info.rob_id): - m.d.comb += precommitting.eq(1) - m.d.comb += exe_side_fx.eq(info.side_fx) - return m diff --git a/coreblocks/func_blocks/fu/common/rs.py b/coreblocks/func_blocks/fu/common/rs.py index 16afaf9c2..3c39045ad 100644 --- a/coreblocks/func_blocks/fu/common/rs.py +++ b/coreblocks/func_blocks/fu/common/rs.py @@ -2,7 +2,6 @@ from collections.abc import Iterable from typing import Optional from amaranth import * -from amaranth.lib.coding import PriorityEncoder from transactron import Method, Transaction, def_method, TModule from coreblocks.params import GenParams from coreblocks.arch import OpType @@ -11,6 +10,7 @@ from transactron.utils import RecordDict from transactron.utils import assign from transactron.utils.assign import AssignType +from transactron.utils.amaranth_ext.coding import PriorityEncoder from transactron.utils.amaranth_ext.functions import popcount from transactron.utils.transactron_helpers import make_layout diff --git a/test/transactron/__init__.py b/coreblocks/func_blocks/fu/fpu/__init__.py similarity index 100% rename from 
test/transactron/__init__.py rename to coreblocks/func_blocks/fu/fpu/__init__.py diff --git a/coreblocks/func_blocks/fu/fpu/fpu_common.py b/coreblocks/func_blocks/fu/fpu/fpu_common.py new file mode 100644 index 000000000..14ad02739 --- /dev/null +++ b/coreblocks/func_blocks/fu/fpu/fpu_common.py @@ -0,0 +1,38 @@ +from amaranth.lib import enum + + +class RoundingModes(enum.Enum): + ROUND_UP = 3 + ROUND_DOWN = 2 + ROUND_ZERO = 1 + ROUND_NEAREST_EVEN = 0 + ROUND_NEAREST_AWAY = 4 + + +class Errors(enum.IntFlag): + INVALID_OPERATION = enum.auto() + DIVISION_BY_ZERO = enum.auto() + OVERFLOW = enum.auto() + UNDERFLOW = enum.auto() + INEXACT = enum.auto() + + +class FPUParams: + """FPU parameters + + Parameters + ---------- + sig_width: int + Width of significand, including implicit bit + exp_width: int + Width of exponent + """ + + def __init__( + self, + *, + sig_width: int = 24, + exp_width: int = 8, + ): + self.sig_width = sig_width + self.exp_width = exp_width diff --git a/coreblocks/func_blocks/fu/fpu/fpu_error_module.py b/coreblocks/func_blocks/fu/fpu/fpu_error_module.py new file mode 100644 index 000000000..5759f34f5 --- /dev/null +++ b/coreblocks/func_blocks/fu/fpu/fpu_error_module.py @@ -0,0 +1,176 @@ +from amaranth import * +from transactron import TModule, Method, def_method +from coreblocks.func_blocks.fu.fpu.fpu_common import ( + RoundingModes, + FPUParams, + Errors, +) + + +class FPUErrorMethodLayout: + """FPU error checking module layouts for methods + + Parameters + ---------- + fpu_params: FPUParams + FPU parameters + """ + + def __init__(self, *, fpu_params: FPUParams): + """ + input_inf is a flag that comes from previous stage. + Its purpose is to indicate that the infinity on input + is a result of infinity arithmetic and not a result of overflow + """ + self.error_in_layout = [ + ("sign", 1), + ("sig", fpu_params.sig_width), + ("exp", fpu_params.exp_width), + ("rounding_mode", RoundingModes), + ("inexact", 1), + ("invalid_operation", 1), + ("division_by_zero", 1), + ("input_inf", 1), + ] + self.error_out_layout = [ + ("sign", 1), + ("sig", fpu_params.sig_width), + ("exp", fpu_params.exp_width), + ("errors", Errors), + ] + + +class FPUErrorModule(Elaboratable): + """FPU error checking module + + Parameters + ---------- + fpu_params: FPUParams + FPU rounding module parameters + + Attributes + ---------- + error_checking_request: Method + Transactional method for initiating error checking of a floating point number. 
+ Takes 'error_in_layout' as argument + Returns final number and errors as 'error_out_layout' + """ + + def __init__(self, *, fpu_params: FPUParams): + + self.fpu_errors_params = fpu_params + self.method_layouts = FPUErrorMethodLayout(fpu_params=self.fpu_errors_params) + self.error_checking_request = Method( + i=self.method_layouts.error_in_layout, + o=self.method_layouts.error_out_layout, + ) + + def elaborate(self, platform): + m = TModule() + + max_exp = C( + 2 ** (self.fpu_errors_params.exp_width) - 1, + unsigned(self.fpu_errors_params.exp_width), + ) + max_normal_exp = C( + 2 ** (self.fpu_errors_params.exp_width) - 2, + unsigned(self.fpu_errors_params.exp_width), + ) + max_sig = C( + 2 ** (self.fpu_errors_params.sig_width) - 1, + unsigned(self.fpu_errors_params.sig_width), + ) + + overflow = Signal() + underflow = Signal() + inexact = Signal() + tininess = Signal() + + final_exp = Signal(self.fpu_errors_params.exp_width) + final_sig = Signal(self.fpu_errors_params.sig_width) + final_sign = Signal() + final_errors = Signal(5) + + @def_method(m, self.error_checking_request) + def _(arg): + is_nan = arg.invalid_operation | ((arg.exp == max_exp) & (arg.sig.any())) + is_inf = arg.division_by_zero | arg.input_inf + input_not_special = ~(is_nan) & ~(is_inf) + m.d.av_comb += overflow.eq(input_not_special & (arg.exp == max_exp)) + m.d.av_comb += tininess.eq((arg.exp == 0) & (~arg.sig[-1])) + m.d.av_comb += inexact.eq(overflow | (input_not_special & arg.inexact)) + m.d.av_comb += underflow.eq(tininess & inexact) + + with m.If(is_nan | is_inf): + + m.d.av_comb += final_exp.eq(arg.exp) + m.d.av_comb += final_sig.eq(arg.sig) + m.d.av_comb += final_sign.eq(arg.sign) + + with m.Elif(overflow): + + with m.Switch(arg.rounding_mode): + with m.Case(RoundingModes.ROUND_NEAREST_AWAY, RoundingModes.ROUND_NEAREST_EVEN): + + m.d.av_comb += final_exp.eq(max_exp) + m.d.av_comb += final_sig.eq(0) + m.d.av_comb += final_sign.eq(arg.sign) + + with m.Case(RoundingModes.ROUND_ZERO): + + m.d.av_comb += final_exp.eq(max_normal_exp) + m.d.av_comb += final_sig.eq(max_sig) + m.d.av_comb += final_sign.eq(arg.sign) + + with m.Case(RoundingModes.ROUND_DOWN): + + with m.If(arg.sign): + + m.d.av_comb += final_exp.eq(max_exp) + m.d.av_comb += final_sig.eq(0) + m.d.av_comb += final_sign.eq(arg.sign) + + with m.Else(): + + m.d.av_comb += final_exp.eq(max_normal_exp) + m.d.av_comb += final_sig.eq(max_sig) + m.d.av_comb += final_sign.eq(arg.sign) + + with m.Case(RoundingModes.ROUND_UP): + + with m.If(arg.sign): + + m.d.av_comb += final_exp.eq(max_normal_exp) + m.d.av_comb += final_sig.eq(max_sig) + m.d.av_comb += final_sign.eq(arg.sign) + + with m.Else(): + + m.d.av_comb += final_exp.eq(max_exp) + m.d.av_comb += final_sig.eq(0) + m.d.av_comb += final_sign.eq(arg.sign) + + with m.Else(): + with m.If((arg.exp == 0) & (arg.sig[-1] == 1)): + m.d.av_comb += final_exp.eq(1) + with m.Else(): + m.d.av_comb += final_exp.eq(arg.exp) + m.d.av_comb += final_sig.eq(arg.sig) + m.d.av_comb += final_sign.eq(arg.sign) + + m.d.av_comb += final_errors.eq( + Mux(arg.invalid_operation, Errors.INVALID_OPERATION, 0) + | Mux(arg.division_by_zero, Errors.DIVISION_BY_ZERO, 0) + | Mux(overflow, Errors.OVERFLOW, 0) + | Mux(underflow, Errors.UNDERFLOW, 0) + | Mux(inexact, Errors.INEXACT, 0) + ) + + return { + "exp": final_exp, + "sig": final_sig, + "sign": final_sign, + "errors": final_errors, + } + + return m diff --git a/coreblocks/func_blocks/fu/fpu/fpu_rounding_module.py b/coreblocks/func_blocks/fu/fpu/fpu_rounding_module.py new file mode 100644 index 
000000000..267d8557d --- /dev/null +++ b/coreblocks/func_blocks/fu/fpu/fpu_rounding_module.py @@ -0,0 +1,117 @@ +from amaranth import * +from transactron import TModule, Method, def_method +from coreblocks.func_blocks.fu.fpu.fpu_common import ( + RoundingModes, + FPUParams, +) + + +class FPURoundingMethodLayout: + """FPU Rounding module layouts for methods + + Parameters + ---------- + fpu_params: FPUParams + FPU parameters + """ + + def __init__(self, *, fpu_params: FPUParams): + self.rounding_in_layout = [ + ("sign", 1), + ("sig", fpu_params.sig_width), + ("exp", fpu_params.exp_width), + ("round_bit", 1), + ("sticky_bit", 1), + ("rounding_mode", RoundingModes), + ] + self.rounding_out_layout = [ + ("sig", fpu_params.sig_width), + ("exp", fpu_params.exp_width), + ("inexact", 1), + ] + + +class FPURounding(Elaboratable): + """FPU Rounding module + + Parameters + ---------- + fpu_params: FPUParams + FPU parameters + + Attributes + ---------- + rounding_request: Method + Transactional method for initiating rounding of a floating point number. + Takes 'rounding_in_layout' as argument + Returns the rounded number and the inexact flag as 'rounding_out_layout' + """ + + def __init__(self, *, fpu_params: FPUParams): + + self.fpu_rounding_params = fpu_params + self.method_layouts = FPURoundingMethodLayout(fpu_params=self.fpu_rounding_params) + self.rounding_request = Method( + i=self.method_layouts.rounding_in_layout, + o=self.method_layouts.rounding_out_layout, + ) + + def elaborate(self, platform): + m = TModule() + + add_one = Signal() + inc_rtnte = Signal() + inc_rtnta = Signal() + inc_rtpi = Signal() + inc_rtmi = Signal() + + rounded_sig = Signal(self.fpu_rounding_params.sig_width + 1) + normalised_sig = Signal(self.fpu_rounding_params.sig_width) + rounded_exp = Signal(self.fpu_rounding_params.exp_width) + + final_round_bit = Signal() + final_sticky_bit = Signal() + + inexact = Signal() + + @def_method(m, self.rounding_request) + def _(arg): + + m.d.av_comb += inc_rtnte.eq( + (arg.rounding_mode == RoundingModes.ROUND_NEAREST_EVEN) + & (arg.round_bit & (arg.sticky_bit | arg.sig[0])) + ) + m.d.av_comb += inc_rtnta.eq((arg.rounding_mode == RoundingModes.ROUND_NEAREST_AWAY) & (arg.round_bit)) + m.d.av_comb += inc_rtpi.eq( + (arg.rounding_mode == RoundingModes.ROUND_UP) & (~arg.sign & (arg.round_bit | arg.sticky_bit)) + ) + m.d.av_comb += inc_rtmi.eq( + (arg.rounding_mode == RoundingModes.ROUND_DOWN) & (arg.sign & (arg.round_bit | arg.sticky_bit)) + ) + + m.d.av_comb += add_one.eq(inc_rtmi | inc_rtnta | inc_rtnte | inc_rtpi) + + m.d.av_comb += rounded_sig.eq(arg.sig + add_one) + + with m.If(rounded_sig[-1]): + + m.d.av_comb += normalised_sig.eq(rounded_sig >> 1) + m.d.av_comb += final_round_bit.eq(rounded_sig[0]) + m.d.av_comb += final_sticky_bit.eq(arg.round_bit | arg.sticky_bit) + m.d.av_comb += rounded_exp.eq(arg.exp + 1) + + with m.Else(): + m.d.av_comb += normalised_sig.eq(rounded_sig) + m.d.av_comb += final_round_bit.eq(arg.round_bit) + m.d.av_comb += final_sticky_bit.eq(arg.sticky_bit) + m.d.av_comb += rounded_exp.eq(arg.exp) + + m.d.av_comb += inexact.eq(final_round_bit | final_sticky_bit) + + return { + "exp": rounded_exp, + "sig": normalised_sig, + "inexact": inexact, + } + + return m diff --git a/coreblocks/func_blocks/fu/lsu/dummyLsu.py b/coreblocks/func_blocks/fu/lsu/dummyLsu.py index cef1daa4e..343ad8a6e 100644 --- a/coreblocks/func_blocks/fu/lsu/dummyLsu.py +++ b/coreblocks/func_blocks/fu/lsu/dummyLsu.py @@ -15,7 +15,7 @@ from coreblocks.func_blocks.interface.func_protocols import FuncUnit from
coreblocks.func_blocks.fu.lsu.pma import PMAChecker from coreblocks.func_blocks.fu.lsu.lsu_requester import LSURequester -from coreblocks.interface.keys import ExceptionReportKey, CommonBusDataKey, InstructionPrecommitKey +from coreblocks.interface.keys import CoreStateKey, ExceptionReportKey, CommonBusDataKey, InstructionPrecommitKey __all__ = ["LSUDummy", "LSUComponent"] @@ -55,6 +55,11 @@ def elaborate(self, platform): m = TModule() flush = Signal() # exception handling, requests are not issued + with Transaction().body(m): + core_state = self.dependency_manager.get_dependency(CoreStateKey()) + state = core_state(m) + m.d.comb += flush.eq(state.flushing) + # Signals for handling issue logic request_rob_id = Signal(self.gen_params.rob_entries_bits) rob_id_match = Signal() @@ -142,11 +147,8 @@ def _(): with Transaction().body(m): precommit = self.dependency_manager.get_dependency(InstructionPrecommitKey()) - info = precommit(m) - with m.If(info.rob_id == request_rob_id): - m.d.comb += rob_id_match.eq(1) - with m.If(~info.side_fx): - m.d.comb += flush.eq(1) + precommit(m, request_rob_id) + m.d.comb += rob_id_match.eq(1) return m diff --git a/coreblocks/func_blocks/fu/priv.py b/coreblocks/func_blocks/fu/priv.py index 5a2b13742..3d861fbdb 100644 --- a/coreblocks/func_blocks/fu/priv.py +++ b/coreblocks/func_blocks/fu/priv.py @@ -101,32 +101,31 @@ def _(arg): with Transaction().body(m, request=instr_valid & ~finished): precommit = self.dm.get_dependency(InstructionPrecommitKey()) - info = precommit(m) - with m.If(info.rob_id == instr_rob): - m.d.sync += finished.eq(1) - self.perf_instr.incr(m, instr_fn, cond=info.side_fx) - - priv_data = priv_mode.read(m).data - - illegal_mret = (instr_fn == PrivilegedFn.Fn.MRET) & (priv_data != PrivilegeLevel.MACHINE) - # future todo: WFI should be illegal in U-Mode only if S-Mode is supported - illegal_wfi = ( - (instr_fn == PrivilegedFn.Fn.WFI) - & (priv_data == PrivilegeLevel.USER) - & csr.m_mode.mstatus_tw.read(m).data - ) - - with condition(m, nonblocking=True) as branch: - with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.MRET) & ~illegal_mret): - mret(m) - with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.FENCEI)): - flush_icache(m) - with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.WFI) & ~illegal_wfi): - # async_interrupt_active implies wfi_resume. WFI should continue normal execution - # when interrupt is enabled in xie, but disabled via global mstatus.xIE - m.d.sync += finished.eq(wfi_resume) - - m.d.sync += illegal_instruction.eq(illegal_wfi | illegal_mret) + info = precommit(m, instr_rob) + m.d.sync += finished.eq(1) + self.perf_instr.incr(m, instr_fn, cond=info.side_fx) + + priv_data = priv_mode.read(m).data + + illegal_mret = (instr_fn == PrivilegedFn.Fn.MRET) & (priv_data != PrivilegeLevel.MACHINE) + # future todo: WFI should be illegal in U-Mode only if S-Mode is supported + illegal_wfi = ( + (instr_fn == PrivilegedFn.Fn.WFI) + & (priv_data == PrivilegeLevel.USER) + & csr.m_mode.mstatus_tw.read(m).data + ) + + with condition(m, nonblocking=True) as branch: + with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.MRET) & ~illegal_mret): + mret(m) + with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.FENCEI)): + flush_icache(m) + with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.WFI) & ~illegal_wfi): + # async_interrupt_active implies wfi_resume. 
WFI should continue normal execution + # when interrupt is enabled in xie, but disabled via global mstatus.xIE + m.d.sync += finished.eq(wfi_resume) + + m.d.sync += illegal_instruction.eq(illegal_wfi | illegal_mret) @def_method(m, self.accept, ready=instr_valid & finished) def _(): diff --git a/coreblocks/interface/layouts.py b/coreblocks/interface/layouts.py index a19ecca45..048db5085 100644 --- a/coreblocks/interface/layouts.py +++ b/coreblocks/interface/layouts.py @@ -355,7 +355,9 @@ class RetirementLayouts: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - self.precommit = make_layout(fields.rob_id, fields.side_fx) + self.precommit_in = make_layout(fields.rob_id) + + self.precommit_out = make_layout(fields.side_fx) self.flushing = ("flushing", 1) """ Core is currently flushed """ @@ -581,10 +583,6 @@ class LSULayouts: def __init__(self, gen_params: GenParams): fields = gen_params.get(CommonLayoutFields) - retirement = gen_params.get(RetirementLayouts) - - self.precommit = retirement.precommit - self.store: LayoutListField = ("store", 1) self.issue = make_layout(fields.addr, fields.data, fields.funct3, self.store) @@ -638,10 +636,6 @@ def __init__(self, gen_params: GenParams): self.rs = gen_params.get(RSInterfaceLayouts, rs_entries_bits=self.rs_entries_bits, data_layout=data_layout) - retirement = gen_params.get(RetirementLayouts) - - self.precommit = retirement.precommit - class ExceptionRegisterLayouts: """Layouts used in the exception information register.""" diff --git a/coreblocks/params/configurations.py b/coreblocks/params/configurations.py index 24162058f..c7727e335 100644 --- a/coreblocks/params/configurations.py +++ b/coreblocks/params/configurations.py @@ -32,6 +32,13 @@ [ALUComponent(), ShiftUnitComponent(), JumpComponent(), ExceptionUnitComponent(), PrivilegedUnitComponent()], rs_entries=4, ), + RSBlockComponent( + [ + MulComponent(mul_unit_type=MulType.SEQUENCE_MUL), + DivComponent(), + ], + rs_entries=2, + ), RSBlockComponent([LSUComponent()], rs_entries=2, rs_type=FifoRS), CSRBlockComponent(), ) @@ -127,7 +134,7 @@ def __post_init__(self): instr_buffer_size: int = 4 - interrupt_custom_count: int = 0 + interrupt_custom_count: int = 16 interrupt_custom_edge_trig_mask: int = 0 user_mode: bool = True @@ -139,7 +146,9 @@ def __post_init__(self): _implied_extensions: Extension = Extension(0) _generate_test_hardware: bool = False - pma: list[PMARegion] = field(default_factory=list) + pma: list[PMARegion] = field( + default_factory=lambda: [PMARegion(0xE0000000, 0xFFFFFFFF, mmio=True)] + ) # default I/O region used in LiteX coreblocks class CoreConfiguration(_CoreConfigurationDataClass): @@ -155,12 +164,15 @@ def replace(self, **kwargs) -> Self: tiny_core_config = CoreConfiguration( embedded=True, func_units_config=( - RSBlockComponent([ALUComponent(), ShiftUnitComponent(), JumpComponent()], rs_entries=2), + RSBlockComponent( + [ALUComponent(), ShiftUnitComponent(), JumpComponent(), ExceptionUnitComponent()], rs_entries=2 + ), RSBlockComponent([LSUComponent()], rs_entries=2, rs_type=FifoRS), ), phys_regs_bits=basic_core_config.phys_regs_bits - 1, rob_entries_bits=basic_core_config.rob_entries_bits - 1, - allow_partial_extensions=True, # No exception unit + icache_enable=False, + user_mode=False, ) # Core configuration with all supported components diff --git a/coreblocks/priv/csr/csr_instances.py b/coreblocks/priv/csr/csr_instances.py index 3215b42ee..33b80f04d 100644 --- a/coreblocks/priv/csr/csr_instances.py +++ 
b/coreblocks/priv/csr/csr_instances.py @@ -4,7 +4,7 @@ from coreblocks.arch import CSRAddress from coreblocks.arch.csr_address import MstatusFieldOffsets from coreblocks.arch.isa import Extension -from coreblocks.arch.isa_consts import PrivilegeLevel, XlenEncoding +from coreblocks.arch.isa_consts import PrivilegeLevel, XlenEncoding, TrapVectorMode from coreblocks.params.genparams import GenParams from coreblocks.priv.csr.csr_register import CSRRegister from coreblocks.priv.csr.aliased import AliasedCSR @@ -75,15 +75,17 @@ def __init__(self, gen_params: GenParams): self.mcause = CSRRegister(CSRAddress.MCAUSE, gen_params) - # SPEC: The mtvec register must always be implemented, but can contain a read-only value. - # set `MODE` as fixed to 0 - Direct mode "All exceptions set pc to BASE" - self.mtvec = CSRRegister(CSRAddress.MTVEC, gen_params, ro_bits=0b11) + self.mtvec = AliasedCSR(CSRAddress.MTVEC, gen_params) mepc_ro_bits = 0b1 if Extension.C in gen_params.isa.extensions else 0b11 # pc alignment (SPEC) self.mepc = CSRRegister(CSRAddress.MEPC, gen_params, ro_bits=mepc_ro_bits) self.mtval = CSRRegister(CSRAddress.MTVAL, gen_params) + self.misa = CSRRegister( + CSRAddress.MISA, gen_params, init=self._misa_value(gen_params), ro_bits=(1 << gen_params.isa.xlen) - 1 + ) + self.priv_mode = CSRRegister( None, gen_params, @@ -94,7 +96,8 @@ def __init__(self, gen_params: GenParams): self.priv_mode_public = AliasedCSR(CSRAddress.COREBLOCKS_TEST_PRIV_MODE, gen_params) self.priv_mode_public.add_field(0, self.priv_mode) - self.mstatus_fields_implementation(gen_params, self.mstatus, self.mstatush) + self._mstatus_fields_implementation(gen_params, self.mstatus, self.mstatush) + self._mtvec_fields_implementation(gen_params, self.mtvec) def elaborate(self, platform): m = Module() @@ -105,7 +108,20 @@ def elaborate(self, platform): return m - def mstatus_fields_implementation(self, gen_params: GenParams, mstatus: AliasedCSR, mstatush: AliasedCSR): + def _mtvec_fields_implementation(self, gen_params: GenParams, mtvec: AliasedCSR): + def filter_legal_mode(m: TModule, v: Value): + legal = Signal(1) + m.d.av_comb += legal.eq((v == TrapVectorMode.DIRECT) | (v == TrapVectorMode.VECTORED)) + return (legal, v) + + self.mtvec_base = CSRRegister(None, gen_params, width=gen_params.isa.xlen - 2) + mtvec.add_field(TrapVectorMode.as_shape().width, self.mtvec_base) + self.mtvec_mode = CSRRegister( + None, gen_params, width=TrapVectorMode.as_shape().width, fu_write_filtermap=filter_legal_mode + ) + mtvec.add_field(0, self.mtvec_mode) + + def _mstatus_fields_implementation(self, gen_params: GenParams, mstatus: AliasedCSR, mstatush: AliasedCSR): def filter_legal_priv_mode(m: TModule, v: Value): legal = Signal(1) with m.Switch(v): @@ -179,6 +195,35 @@ def filter_legal_priv_mode(m: TModule, v: Value): Extension.V in gen_params.isa.extensions or Extension.F in gen_params.isa.extensions, ) + def _misa_value(self, gen_params): + misa_value = 0 + + misa_extension_bits = { + 0: Extension.A, + 1: Extension.B, + 2: Extension.C, + 3: Extension.D, + 4: Extension.E, + 5: Extension.F, + 8: Extension.I, + 12: Extension.M, + 16: Extension.Q, + 21: Extension.V, + } + + for bit, extension in misa_extension_bits.items(): + if extension in gen_params.isa.extensions: + misa_value |= 1 << bit + + if gen_params.user_mode: + misa_value |= 1 << 20 + # 7 - Hypervisor, 18 - Supervisor, 23 - Custom Extensions + + xml_field_mapping = {32: XlenEncoding.W32, 64: XlenEncoding.W64, 128: XlenEncoding.W128} + misa_value |= xml_field_mapping[gen_params.isa.xlen] 
<< (gen_params.isa.xlen - 2) + + return misa_value + class GenericCSRRegisters(Elaboratable): def __init__(self, gen_params: GenParams): diff --git a/docs/api.md b/docs/api.md index 5daa246b7..226f38e51 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,5 +2,4 @@ ```{eval-rst} .. include:: modules-coreblocks.rst -.. include:: modules-transactron.rst ``` diff --git a/docs/index.md b/docs/index.md index 0e16a25ec..6a9b5afba 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,7 +8,6 @@ maxdepth: 3 home.md assumptions.md development-environment.md -transactions.md scheduler/overview.md shared-structs/implementation/rs-impl.md shared-structs/rs.md diff --git a/docs/transactions.md b/docs/transactions.md deleted file mode 100644 index 41b5d5528..000000000 --- a/docs/transactions.md +++ /dev/null @@ -1,336 +0,0 @@ -# Documentation for Coreblocks transaction framework - -## Introduction - -Coreblocks utilizes a transaction framework for modularizing the design. -It is inspired by the [Bluespec](http://bluespec.com/) programming language (see: [Bluespec wiki](http://wiki.bluespec.com/), [Bluespec compiler](https://github.com/B-Lang-org/bsc)). - -The basic idea is to interface hardware modules using _transactions_ and _methods_. -A transaction is a state-changing operation performed by the hardware in a single clock cycle. -Transactions are atomic: in a given clock cycle, a transaction either executes in its entriety, or not at all. -A transaction is executed only if it is ready for execution and it does not _conflict_ with another transaction scheduled for execution in the same clock cycle. - -A transaction defined in a given hardware module can depend on other hardware modules via the use of methods. -A method can be _called_ by a transaction or by other methods. -Execution of methods is directly linked to the execution of transactions: a method only executes if some transaction which calls the method (directly or indirectly, via other methods) is executed. -If multiple transactions try to call the same method in the same clock cycle, the transactions conflict, and only one of them is executed. -In this way, access to methods is coordinated via the transaction system to avoid conflicts. - -Methods can communicate with their callers in both directions: from caller to method and back. -The communication is structured using Amaranth records. - -## Basic usage - -### Implementing transactions - -The simplest way to implement a transaction as a part of Amaranth `Elaboratable` is by using a `with` block: - -```python -class MyThing(Elaboratable): - ... - - def elaborate(self, platform): - m = TModule() - - ... - - with Transaction().body(m): - # Operations conditioned on the transaction executing. - # Including Amaranth assignments, like: - - m.d.comb += sig1.eq(expr1) - m.d.sync += sig2.eq(expr2) - - # Method calls can also be used, like: - - result = self.method(m, arg_expr) - - ... - - return m -``` - -The transaction body `with` block works analogously to Amaranth's `with m.If():` blocks: the Amaranth assignments and method calls only "work" in clock cycles when the transaction is executed. -This is implemented in hardware via multiplexers. -Please remember that this is not a Python `if` statement -- the *Python code* inside the `with` block is always executed once. 
- -### Implementing methods - -As methods are used as a way to communicate with other `Elaboratable`s, they are typically declared in the `Elaboratable`'s constructor, and then defined in the `elaborate` method: - -```python -class MyOtherThing(Elaboratable): - def __init__(self): - ... - - # Declaration of the method. - # The i/o parameters pass the format of method argument/result as Amaranth layouts. - # Both parameters are optional. - - self.my_method = Method(i=input_layout, o=output_layout) - - ... - - def elaborate(self, platform): - # A TModule needs to be used instead of an Amaranth module - - m = TModule() - - ... - - @def_method(m, self.my_method) - def _(arg): - # Operations conditioned on the method executing. - # Including Amaranth assignments, like: - - m.d.comb += sig1.eq(expr1) - m.d.sync += sig2.eq(expr2) - - # Method calls can also be used, like: - - result = self.other_method(m, arg_expr) - - # Method result should be returned: - - return ret_expr - - ... - - return m -``` - -The `def_method` technique presented above is a convenience syntax, but it works just like other Amaranth `with` blocks. -In particular, the *Python code* inside the unnamed `def` function is always executed once. - -A method defined in one `Elaboratable` is usually passed to other `Elaboratable`s via constructor parameters. -For example, the `MyThing` constructor could be defined as follows. -Only methods should be passed around, not entire `Elaboratable`s! - -```python -class MyThing(Elaboratable): - def __init__(self, method: Method): - self.method = method - - ... - - ... -``` - -### Method or transaction? - -Sometimes, there might be two alternative ways to implement some functionality: - -* Using a transaction, which calls methods on other `Elaboratable`s. -* Using a method, which is called from other `Elaboratable`s. - -Deciding on a best method is not always easy. -An important question to ask yourself is -- is this functionality something that runs independently from other things (not in lock-step)? -If so, maybe it should be a transaction. -Or is it something that is dependent on some external condition? -If so, maybe it should be a method. - -If in doubt, methods are preferred. -This is because if a functionality is implemented as a method, and a transaction is needed, one can use a transaction which calls this method and does nothing else. -Such a transaction is included in the library -- it's named `AdapterTrans`. - -### Method argument passing conventions - -Even though method arguments are Amaranth records, their use can be avoided in many cases, which results in cleaner code. -Suppose we have the following layout, which is an input layout for a method called `method`: - -```python -layout = [("foo", 1), ("bar", 32)] -method = Method(input_layout=layout) -``` - -The method can be called in multiple ways. -The cleanest and recommended way is to pass each record field using a keyword argument: - -```python -method(m, foo=foo_expr, bar=bar_expr) -``` - -Another way is to pass the arguments using a `dict`: - -```python -method(m, {'foo': foo_expr, 'bar': bar_expr}) -``` - -Finally, one can directly pass an Amaranth record: - -```python -rec = Record(layout) -m.d.comb += rec.foo.eq(foo_expr) -m.d.comb += rec.bar.eq(bar_expr) -method(m, rec) -``` - -The `dict` convention can be used recursively when layouts are nested. 
-Take the following definitions: - -```python -layout2 = [("foobar", layout), ("baz", 42)] -method2 = Method(input_layout=layout2) -``` - -One can then pass the arguments using `dict`s in following ways: - -```python -# the preferred way -method2(m, foobar={'foo': foo_expr, 'bar': bar_expr}, baz=baz_expr) - -# the alternative way -method2(m, {'foobar': {'foo': foo_expr, 'bar': bar_expr}, 'baz': baz_expr}) -``` - -### Method definition conventions - -When defining methods, two conventions can be used. -The cleanest and recommended way is to create an argument for each record field: - -```python -@def_method(m, method) -def _(foo: Value, bar: Value): - ... -``` - -The other is to receive the argument record directly. The `arg` name is required: - -```python -def_method(m, method) -def _(arg: Record): - ... -``` - -### Method return value conventions - -The `dict` syntax can be used for returning values from methods. -Take the following method declaration: - -```python -method3 = Method(input_layout=layout, output_layout=layout2) -``` - -One can then define this method as follows: - -```python -@def_method(m, method3) -def _(foo: Value, bar: Value): - return {{'foo': foo, 'bar': foo + bar}, 'baz': foo - bar} -``` - -### Readiness signals - -If a transaction is not always ready for execution (for example, because of the dependence on some resource), a `request` parameter should be used. -An Amaranth single-bit expression should be passed. -When the `request` parameter is not passed, the transaction is always requesting execution. - -```python - with Transaction().body(m, request=expr): -``` - -Methods have a similar mechanism, which uses the `ready` parameter on `def_method`: - -```python - @def_method(m, self.my_method, ready=expr) - def _(arg): - ... -``` - -The `request` signal typically should only depend on the internal state of an `Elaboratable`. -Other dependencies risk introducing combinational loops. -In certain occasions, it is possible to relax this requirement; see e.g. [Scheduling order](#scheduling-order). - -## The library - -The transaction framework is designed to facilitate code re-use. -It includes a library, which contains `Elaboratable`s providing useful methods and transactions. -The most useful ones are: - -* `ConnectTrans`, for connecting two methods together with a transaction. -* `FIFO`, for queues accessed with two methods, `read` and `write`. -* `Adapter` and `AdapterTrans`, for communicating with transactions and methods from plain Amaranth code. - These are very useful in testbenches. - -## Advanced concepts - -### Special combinational domains - -Transactron defines its own variant of Amaranth modules, called `TModule`. -Its role is to allow to improve circuit performance by omitting unneeded multiplexers in combinational circuits. -This is done by adding two additional, special combinatorial domains, `av_comb` and `top_comb`. - -Statements added to the `av_comb` domain (the "avoiding" domain) are not executed when under a false `m.If`, but are executed when under a false `m.AvoidedIf`. -Transaction and method bodies are internally guarded by an `m.AvoidedIf` with the transaction `grant` or method `run` signal. -Therefore combinational assignments added to `av_comb` work even if the transaction or method definition containing the assignments are not running. -Because combinational signals usually don't induce state changes, this is often safe to do and improves performance. 
- -Statements added to the `top_comb` domain are always executed, even if the statement is under false conditions (including `m.If`, `m.Switch` etc.). -This allows for cleaner code, as combinational assignments which logically belong to some case, but aren't actually required to be there, can be as performant as if they were manually moved to the top level. - -An important caveat of the special domains is that, just like with normal domains, a signal assigned in one of them cannot be assigned in others. - -### Scheduling order - -When writing multiple methods and transactions in the same `Elaboratable`, sometimes some dependency between them needs to exist. -For example, in the `Forwarder` module in the library, forwarding can take place only if both `read` and `write` are executed simultaneously. -This requirement is handled by making the the `read` method's readiness depend on the execution of the `write` method. -If the `read` method was considered for execution before `write`, this would introduce a combinational loop into the circuit. -In order to avoid such issues, one can require a certain scheduling order between methods and transactions. - -`Method` and `Transaction` objects include a `schedule_before` method. -Its only argument is another `Method` or `Transaction`, which will be scheduled after the first one: - -```python -first_t_or_m.schedule_before(other_t_or_m) -``` - -Internally, scheduling orders exist only on transactions. -If a scheduling order is added to a `Method`, it is lifted to the transaction level. -For example, if `first_m` is scheduled before `other_t`, and is called by `t1` and `t2`, the added scheduling orderings will be the same as if the following calls were made: - -```python -t1.schedule_before(other_t) -t2.schedule_before(other_t) -``` - -### Conflicts - -In some situations it might be useful to make some methods or transactions mutually exclusive with others. -Two conflicting transactions or methods can't execute simultaneously: only one or the other runs in a given clock cycle. - -Conflicts are defined similarly to scheduling orders: - -```python -first_t_or_m.add_conflict(other_t_or_m) -``` - -Conflicts are lifted to the transaction level, just like scheduling orders. - -The `add_conflict` method has an optional argument `priority`, which allows to define a scheduling order between conflicting transactions or methods. -Possible values are `Priority.LEFT`, `Priority.RIGHT` and `Priority.UNDEFINED` (the default). -For example, the following code adds a conflict with a scheduling order, where `first_m` is scheduled before `other_m`: - -```python -first_m.add_conflict(other_m, priority = Priority.LEFT) -``` - -Scheduling conflicts come with a possible cost. -The conflicting transactions have a dependency in the transaction scheduler, which can increase the size and combinational delay of the scheduling circuit. -Therefore, use of this feature requires consideration. - -### Transaction and method nesting - -Transaction and method bodies can be nested. For example: - -```python -with Transaction().body(m): - # Transaction body. - - with Transaction().body(m): - # Nested transaction body. -``` - -Nested transactions and methods can only run if the parent also runs. -The converse is not true: it is possible that only the parent runs, but the nested transaction or method doesn't (because of other limitations). -Nesting implies scheduling order: the nested transaction or method is considered for execution after the parent. 
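[Reviewer note, not part of the patch] The precommit refactor spread across retirement.py, dummyLsu.py, csr.py and priv.py replaces the old broadcast-and-compare pattern (call `precommit(m)`, then compare the returned `rob_id` yourself) with argument validation on the method. Below is a minimal sketch of both sides of the new interface, assuming the Transactron API as used in this diff; `layouts`, `rob_peek`, `side_fx` and the caller-side signals are stand-ins for the corresponding Coreblocks objects, not real names beyond those visible above.

```python
from amaranth import Signal
from transactron import Method, TModule, Transaction, def_method


def define_precommit(m: TModule, layouts, rob_peek: Method, side_fx: Signal, rob_entries_bits: int) -> Method:
    # Provider side (cf. retirement.py): the rob_id argument is consumed only
    # by validate_arguments, never in the method body, so a dummy combiner
    # suffices for the nonexclusive method.
    rob_id_val = Signal(rob_entries_bits)
    precommit = Method(
        i=layouts.precommit_in,
        o=layouts.precommit_out,
        nonexclusive=True,
        combiner=lambda m, args, runs: 0,
    )

    @def_method(m, precommit, validate_arguments=lambda rob_id: rob_id == rob_id_val)
    def _(rob_id):
        # Track the instruction currently at the head of the ROB.
        m.d.top_comb += rob_id_val.eq(rob_peek(m).rob_id)
        return {"side_fx": side_fx}

    return precommit


def call_precommit(m: TModule, precommit: Method, request_rob_id: Signal, side_fx_out: Signal):
    # Caller side (cf. dummyLsu.py): the call simply stalls until the caller's
    # rob_id validates, so no manual rob_id comparison is needed anymore.
    with Transaction().body(m):
        info = precommit(m, request_rob_id)
        m.d.comb += side_fx_out.eq(info.side_fx)
```

Since the method is nonexclusive, several functional units can hold the call open at once; only the caller whose `rob_id` matches the retiring instruction proceeds in a given cycle, which is exactly the behaviour the old per-caller `m.If(info.rob_id == ...)` checks emulated.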
diff --git a/pytest.ini b/pytest.ini index 142b00abe..c2bc22f2b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,5 +4,6 @@ testpaths = tests norecursedirs = '*.egg', '.*', 'build', 'dist', 'venv', '__traces__', '__pycache__' filterwarnings = + ignore:cannot collect test class 'TestbenchContext':pytest.PytestCollectionWarning ignore:cannot collect test class 'TestbenchIO':pytest.PytestCollectionWarning ignore:No files were found in testpaths:pytest.PytestConfigWarning: diff --git a/requirements.txt b/requirements.txt index 7d70b711d..19b874c8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -./amaranth-stubs/ # can't use -e -- pyright doesn't see the stubs then :( +amaranth-stubs @ git+https://github.com/kuznia-rdzeni/amaranth-stubs.git@edb302b001433edf4c8568190adc9bd0c0039f45 +transactron @ git+https://github.com/kuznia-rdzeni/transactron.git@972047b7bfac3d2e193a428de35c976f9b17c51a amaranth-yosys==0.40.0.0.post100 -git+https://github.com/amaranth-lang/amaranth@9bd536bbf96b07720d6e4a8709b30492af8ddd13 +amaranth==0.5.3 dataclasses-json==0.6.3 diff --git a/scripts/build_docs.sh b/scripts/build_docs.sh index 6f58a5a6b..40e56ba89 100755 --- a/scripts/build_docs.sh +++ b/scripts/build_docs.sh @@ -60,5 +60,4 @@ $ROOT_PATH/scripts/core_graph.py -p -f mermaid $DOCS_DIR/auto_graph.rst sed -i -e '1i\.. mermaid::\n' -e 's/^/ /' $DOCS_DIR/auto_graph.rst sphinx-apidoc --tocfile modules-coreblocks -o $DOCS_DIR $ROOT_PATH/coreblocks/ -sphinx-apidoc --tocfile modules-transactron -o $DOCS_DIR $ROOT_PATH/transactron/ sphinx-build -b html -W $DOCS_DIR $BUILD_DIR diff --git a/test/asm/interrupt_vectored.asm b/test/asm/interrupt_vectored.asm new file mode 100644 index 000000000..19c05042f --- /dev/null +++ b/test/asm/interrupt_vectored.asm @@ -0,0 +1,136 @@ +.include "init_regs.s" + +_start: + INIT_REGS_LOAD + + # fibonacci spiced with interrupt handler (also with fibonacci) + li x1, 0x201 + csrw mtvec, x1 + li x1, 0x203 + csrw mtvec, x1 + csrr x16, mtvec # since mtvec is WARL, should stay 0x201 + ecall # synchronous exception jumps to 0x200 + 0x0 + +interrupts: + li x27, 0 # handler count + li x30, 0 # interrupt count + li x31, 0xde # branch guard + + csrsi mstatus, 0x8 # machine interrupt enable + csrr x29, mstatus + li x1, 0x30000 + csrw mie, x1 # enable custom interrupt 0 and 1 + li x1, 0 + li x2, 1 + li x5, 4 + li x6, 7 + li x7, 0 + li x12, 4 + li x13, 7 + li x14, 0 +loop: + add x3, x2, x1 + mv x1, x2 + mv x2, x3 + bne x2, x4, loop +infloop: + j infloop + +int0_handler: + # save main loop register state + mv x9, x1 + mv x10, x2 + mv x11, x3 + + # check cause + li x2, 0x80000010 # cause for 01,11 + csrr x3, mcause + bne x2, x3, fail + + # fibonacci step + beq x7, x8, skip + add x7, x6, x5 + mv x5, x6 + mv x6, x7 + +skip: + # generate new mie mask + andi x2, x30, 0x3 + bnez x2, fill_skip + li x2, 0x3 + fill_skip: + slli x2, x2, 16 + csrw mie, x2 + + # clear interrupts + csrr x1, mip + srli x1, x1, 16 + andi x2, x1, 0x1 + beqz x2, skip_clear_edge + addi x30, x30, 1 + li x2, 0x10000 + csrc mip, x2 # clear edge reported interrupt + skip_clear_edge: + andi x2, x1, 0x2 + beqz x2, skip_clear_level + addi x30, x30, 1 + csrwi 0x7ff, 1 # clear level reported interrupt via custom csr + skip_clear_level: + addi x27, x27, 1 + + # restore main loop register state + mv x1, x9 + mv x2, x10 + mv x3, x11 + mret + +int1_handler: + # save main loop register state + mv x9, x1 + mv x10, x2 + mv x11, x3 + + # check cause + li x2, 0x80000011 # cause for 10 + csrr x3, mcause + bne x2, x3, fail + + # fibonacci 
step + beq x14, x15, skip + add x14, x13, x12 + mv x12, x13 + mv x13, x14 + j skip + +ecall_handler: + li x17, 0x111 + la x1, interrupts + csrw mepc, x1 + mret + +fail: + csrwi 0x7ff, 2 + j fail + +.org 0x200 + j ecall_handler + nop + nop + nop + nop + nop + nop + nop + nop + nop + nop + nop + nop + nop + nop + j fail + j int0_handler + j int1_handler + li x31, 0xae # should never happen + +INIT_REGS_ALLOCATION diff --git a/test/backend/test_annoucement.py b/test/backend/test_annoucement.py index e6fd56fa3..71b83a319 100644 --- a/test/backend/test_annoucement.py +++ b/test/backend/test_annoucement.py @@ -8,7 +8,7 @@ from coreblocks.interface.layouts import * from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config -from transactron.testing import TestCaseWithSimulator, TestbenchIO +from transactron.testing import TestCaseWithSimulator, TestbenchIO, TestbenchContext class BackendTestCircuit(Elaboratable): @@ -104,32 +104,33 @@ def generate_producer(self, i: int): results to its output FIFO. This records will be next serialized by FUArbiter. """ - def producer(): + async def producer(sim: TestbenchContext): inputs = self.fu_inputs[i] for rob_id, result, rp_dst in inputs: io: TestbenchIO = self.m.fu_fifo_ins[i] - yield from io.call_init(rob_id=rob_id, result=result, rp_dst=rp_dst) - yield from self.random_wait(self.max_wait) + io.call_init(sim, rob_id=rob_id, result=result, rp_dst=rp_dst) + await self.random_wait(sim, self.max_wait) self.producer_end[i] = True return producer - def consumer(self): - yield from self.m.rs_announce_val_tbio.enable() - yield from self.m.rob_mark_done_tbio.enable() + async def consumer(self, sim: TestbenchContext): + # TODO: this test doesn't do anything, fix it! + self.m.rs_announce_val_tbio.enable(sim) + self.m.rob_mark_done_tbio.enable(sim) while reduce(and_, self.producer_end, True): # All 3 methods (in RF, RS and ROB) need to be enabled for the result # announcement transaction to take place. We want to have at least one # method disabled most of the time, so that the transaction is performed # only when we enable it inside the loop. 
Otherwise the transaction could # get executed at any time, particularly when we wouldn't be monitoring it - yield from self.m.rf_announce_val_tbio.enable() + self.m.rf_announce_val_tbio.enable(sim) - rf_result = yield from self.m.rf_announce_val_tbio.method_argument() - rs_result = yield from self.m.rs_announce_val_tbio.method_argument() - rob_result = yield from self.m.rob_mark_done_tbio.method_argument() + rf_result = self.m.rf_announce_val_tbio.get_outputs(sim) + rs_result = self.m.rs_announce_val_tbio.get_outputs(sim) + rob_result = self.m.rob_mark_done_tbio.get_outputs(sim) - yield from self.m.rf_announce_val_tbio.disable() + self.m.rf_announce_val_tbio.disable(sim) assert rf_result is not None assert rs_result is not None @@ -144,20 +145,20 @@ def consumer(self): del self.expected_output[t] else: self.expected_output[t] -= 1 - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) def test_one_out(self): self.fu_count = 1 self.initialize() with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) + sim.add_testbench(self.consumer) for i in range(self.fu_count): - sim.add_process(self.generate_producer(i)) + sim.add_testbench(self.generate_producer(i)) def test_many_out(self): self.fu_count = 4 self.initialize() with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) + sim.add_testbench(self.consumer) for i in range(self.fu_count): - sim.add_process(self.generate_producer(i)) + sim.add_testbench(self.generate_producer(i)) diff --git a/test/backend/test_retirement.py b/test/backend/test_retirement.py index bc50290c9..cf039ed13 100644 --- a/test/backend/test_retirement.py +++ b/test/backend/test_retirement.py @@ -12,6 +12,7 @@ from coreblocks.params import GenParams from coreblocks.interface.layouts import ROBLayouts, RFLayouts, SchedulerLayouts from coreblocks.params.configurations import test_core_config +from transactron.lib.adapters import AdapterTrans from transactron.testing import * from collections import deque @@ -120,45 +121,46 @@ def setup_method(self): # (and the retirement code doesn't have any special behaviour to handle these cases), but in this simple # test we don't care to make sure that the randomly generated inputs are correct in this way. 
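The rewrite of test_annoucement.py above is the mechanical pattern for the rest of this patch: a generator-based simulation process driven with `yield`/`yield from` becomes an `async` coroutine that receives an explicit `TestbenchContext` and is registered with `sim.add_testbench` instead of `sim.add_process`. A minimal before/after sketch of the idiom (the `fifo` TestbenchIO and its surrounding test class are illustrative placeholders, not code from this patch):

    from transactron.testing import TestbenchContext, TestbenchIO

    # Old style: a generator process, stepped by the simulator through yield.
    def producer(self):
        for item in self.items:
            yield from self.fifo.call(data=item)  # waits until the method fires

    # New style: an async testbench with an explicit simulator context.
    async def producer(self, sim: TestbenchContext):
        for item in self.items:
            await self.fifo.call(sim, data=item)  # same semantics, awaitable

Two conventions accompany the change: processes that previously started with `yield Passive()` (bus slaves, refiller controllers) are now registered with `sim.add_testbench(process, background=True)`, and, as in the retirement mocks just below, `@def_method_mock` bodies keep only the combinational return value, moving state mutations into a nested function decorated with `@MethodMock.effect` so that they happen only when the mocked method actually runs.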
- @def_method_mock(lambda self: self.retc.mock_rob_retire, enable=lambda self: bool(self.submit_q), sched_prio=1) + @def_method_mock(lambda self: self.retc.mock_rob_retire, enable=lambda self: bool(self.submit_q)) def retire_process(self): - self.submit_q.popleft() + @MethodMock.effect + def eff(): + self.submit_q.popleft() @def_method_mock(lambda self: self.retc.mock_rob_peek, enable=lambda self: bool(self.submit_q)) def peek_process(self): return self.submit_q[0] - def free_reg_process(self): + async def free_reg_process(self, sim: TestbenchContext): while self.rf_exp_q: - reg = yield from self.retc.free_rf_adapter.call() + reg = await self.retc.free_rf_adapter.call(sim) assert reg["reg_id"] == self.rf_exp_q.popleft() - def rat_process(self): + async def rat_process(self, sim: TestbenchContext): while self.rat_map_q: current_map = self.rat_map_q.popleft() wait_cycles = 0 # this test waits for next rat pair to be correctly set and will timeout if that assignment fails - while (yield self.retc.rat.entries[current_map["rl_dst"]]) != current_map["rp_dst"]: + while sim.get(self.retc.rat.entries[current_map["rl_dst"]]) != current_map["rp_dst"]: wait_cycles += 1 if wait_cycles >= self.cycles + 10: assert False, "RAT entry was not updated" - yield Tick() + await sim.tick() assert not self.submit_q assert not self.rf_free_q - def precommit_process(self): - yield from self.retc.precommit_adapter.call_init() + async def precommit_process(self, sim: TestbenchContext): while self.precommit_q: - yield Tick() - info = yield from self.retc.precommit_adapter.call_result() + info = await self.retc.precommit_adapter.call_try(sim, rob_id=self.precommit_q[0]) assert info is not None assert info["side_fx"] - assert self.precommit_q[0] == info["rob_id"] self.precommit_q.popleft() - @def_method_mock(lambda self: self.retc.mock_rf_free, sched_prio=2) + @def_method_mock(lambda self: self.retc.mock_rf_free) def rf_free_process(self, reg_id): - assert reg_id == self.rf_free_q.popleft() + @MethodMock.effect + def eff(): + assert reg_id == self.rf_free_q.popleft() @def_method_mock(lambda self: self.retc.mock_exception_cause) def exception_cause_process(self): @@ -177,7 +179,7 @@ def mock_trap_entry_process(self): pass @def_method_mock(lambda self: self.retc.mock_fetch_continue) - def mock_fetch_continue_process(self): + def mock_fetch_continue_process(self, pc): pass @def_method_mock(lambda self: self.retc.mock_async_interrupt_cause) @@ -187,6 +189,6 @@ def mock_async_interrupt_cause(self): def test_rand(self): self.retc = RetirementTestCircuit(self.gen_params) with self.run_simulation(self.retc) as sim: - sim.add_process(self.free_reg_process) - sim.add_process(self.rat_process) - sim.add_process(self.precommit_process) + sim.add_testbench(self.free_reg_process) + sim.add_testbench(self.rat_process) + sim.add_testbench(self.precommit_process) diff --git a/test/cache/test_icache.py b/test/cache/test_icache.py index a52d75f35..88d44450a 100644 --- a/test/cache/test_icache.py +++ b/test/cache/test_icache.py @@ -3,7 +3,6 @@ import random from amaranth import Elaboratable, Module -from amaranth.sim import Passive, Settle, Tick from amaranth.utils import exact_log2 from transactron.lib import AdapterTrans, Adapter @@ -15,7 +14,10 @@ from coreblocks.params.configurations import test_core_config from coreblocks.cache.refiller import SimpleCommonBusCacheRefiller -from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, RecordIntDictRet +from transactron.testing import TestCaseWithSimulator, 
TestbenchIO, def_method_mock, TestbenchContext +from transactron.testing.functions import MethodData +from transactron.testing.method_mock import MethodMock +from transactron.testing.testbenchio import CallTrigger from ..peripherals.test_wishbone import WishboneInterfaceWrapper @@ -98,35 +100,29 @@ def setup_method(self) -> None: self.bad_addresses.add(bad_addr) self.bad_fetch_blocks.add(bad_addr & ~(self.cp.fetch_block_bytes - 1)) - def wishbone_slave(self): - yield Passive() - + async def wishbone_slave(self, sim: TestbenchContext): while True: - yield from self.test_module.wb_ctrl.slave_wait() + adr, *_ = await self.test_module.wb_ctrl.slave_wait(sim) # Wishbone is addressing words, so we need to shift it a bit to get the real address. - addr = (yield self.test_module.wb_ctrl.wb.adr) << exact_log2(self.cp.word_width_bytes) + addr = adr << exact_log2(self.cp.word_width_bytes) - yield Tick() - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) err = 1 if addr in self.bad_addresses else 0 data = random.randrange(2**self.gen_params.isa.xlen) self.mem[addr] = data - yield from self.test_module.wb_ctrl.slave_respond(data, err=err) - - yield Settle() + await self.test_module.wb_ctrl.slave_respond(sim, data, err=err) - def refiller_process(self): + async def refiller_process(self, sim: TestbenchContext): while self.requests: req_addr = self.requests.pop() - yield from self.test_module.start_refill.call(addr=req_addr) + await self.test_module.start_refill.call(sim, addr=req_addr) for i in range(self.cp.fetch_blocks_in_line): - ret = yield from self.test_module.accept_refill.call() + ret = await self.test_module.accept_refill.call(sim) cur_addr = req_addr + i * self.cp.fetch_block_bytes @@ -149,8 +145,8 @@ def refiller_process(self): def test(self): with self.run_simulation(self.test_module) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.refiller_process) + sim.add_testbench(self.wishbone_slave, background=True) + sim.add_testbench(self.refiller_process) class ICacheBypassTestCircuit(Elaboratable): @@ -220,17 +216,14 @@ def load_or_gen_mem(self, addr: int): self.mem[addr] = random.randrange(2**self.gen_params.isa.ilen) return self.mem[addr] - def wishbone_slave(self): - yield Passive() - + async def wishbone_slave(self, sim: TestbenchContext): while True: - yield from self.m.wb_ctrl.slave_wait() + adr, *_ = await self.m.wb_ctrl.slave_wait(sim) # Wishbone is addressing words, so we need to shift it a bit to get the real address. 
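Since the Wishbone `adr` signal counts bus words rather than bytes, the slave models recover the byte address by shifting left by log2 of the word size. A small worked example, assuming a 4-byte (32-bit) data bus:

    from amaranth.utils import exact_log2

    word_width_bytes = 4                         # assumed bus width for this example
    adr = 0x1000                                 # word address as seen on wb.adr
    addr = adr << exact_log2(word_width_bytes)   # 0x4000: the byte address used to index self.mem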
- addr = (yield self.m.wb_ctrl.wb.adr) << exact_log2(self.cp.word_width_bytes) + addr = adr << exact_log2(self.cp.word_width_bytes) - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) err = 1 if addr in self.bad_addrs else 0 @@ -238,19 +231,16 @@ def wishbone_slave(self): if self.gen_params.isa.xlen == 64: data = self.load_or_gen_mem(addr + 4) << 32 | data - yield from self.m.wb_ctrl.slave_respond(data, err=err) + await self.m.wb_ctrl.slave_respond(sim, data, err=err) - yield Settle() - - def user_process(self): + async def user_process(self, sim: TestbenchContext): while self.requests: req_addr = self.requests.popleft() & ~(self.cp.fetch_block_bytes - 1) - yield from self.m.issue_req.call(addr=req_addr) + await self.m.issue_req.call(sim, addr=req_addr) - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) - ret = yield from self.m.accept_res.call() + ret = await self.m.accept_res.call(sim) if (req_addr & ~(self.cp.word_width_bytes - 1)) in self.bad_addrs: assert ret["error"] @@ -262,13 +252,12 @@ def user_process(self): data |= self.mem[req_addr + 4] << 32 assert ret["fetch_block"] == data - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) def test(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.user_process) + sim.add_testbench(self.wishbone_slave, background=True) + sim.add_testbench(self.user_process) class MockedCacheRefiller(Elaboratable, CacheRefillerInterface): @@ -328,6 +317,7 @@ def setup_method(self) -> None: self.bad_addrs = set() self.bad_cache_lines = set() self.refill_requests = deque() + self.refill_block_cnt = 0 self.issued_requests = deque() self.accept_refill_request = True @@ -351,12 +341,17 @@ def init_module(self, ways, sets) -> None: @def_method_mock(lambda self: self.m.refiller.start_refill_mock, enable=lambda self: self.accept_refill_request) def start_refill_mock(self, addr): - self.refill_requests.append(addr) - self.refill_block_cnt = 0 - self.refill_in_fly = True - self.refill_addr = addr + @MethodMock.effect + def eff(): + self.refill_requests.append(addr) + self.refill_block_cnt = 0 + self.refill_in_fly = True + self.refill_addr = addr + + def enen(self): + return self.refill_in_fly - @def_method_mock(lambda self: self.m.refiller.accept_refill_mock, enable=lambda self: self.refill_in_fly) + @def_method_mock(lambda self: self.m.refiller.accept_refill_mock, enable=enen) def accept_refill_mock(self): addr = self.refill_addr + self.refill_block_cnt * self.cp.fetch_block_bytes @@ -367,12 +362,14 @@ def accept_refill_mock(self): if addr + i in self.bad_addrs: bad_addr = True - self.refill_block_cnt += 1 + last = self.refill_block_cnt + 1 == self.cp.fetch_blocks_in_line or bad_addr - last = self.refill_block_cnt == self.cp.fetch_blocks_in_line or bad_addr + @MethodMock.effect + def eff(): + self.refill_block_cnt += 1 - if last: - self.refill_in_fly = False + if last: + self.refill_in_fly = False return { "addr": addr, @@ -390,18 +387,19 @@ def add_bad_addr(self, addr: int): self.bad_addrs.add(addr) self.bad_cache_lines.add(addr & ~((1 << self.cp.offset_bits) - 1)) - def send_req(self, addr: int): + async def send_req(self, sim: TestbenchContext, addr: int): self.issued_requests.append(addr) - yield from self.m.issue_req.call(addr=addr) + await self.m.issue_req.call(sim, addr=addr) - def expect_resp(self, wait=False): - yield Settle() + async def expect_resp(self, sim: TestbenchContext, wait=False): if 
wait: - yield from self.m.accept_res.wait_until_done() + *_, resp = await self.m.accept_res.sample_outputs_until_done(sim) + else: + *_, resp = await self.m.accept_res.sample_outputs(sim) - self.assert_resp((yield from self.m.accept_res.get_outputs())) + self.assert_resp(resp) - def assert_resp(self, resp: RecordIntDictRet): + def assert_resp(self, resp: MethodData): addr = self.issued_requests.popleft() & ~(self.cp.fetch_block_bytes - 1) if (addr & ~((1 << self.cp.offset_bits) - 1)) in self.bad_cache_lines: @@ -417,343 +415,321 @@ def assert_resp(self, resp: RecordIntDictRet): def expect_refill(self, addr: int): assert self.refill_requests.popleft() == addr - def call_cache(self, addr: int): - yield from self.send_req(addr) - yield from self.m.accept_res.enable() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + async def call_cache(self, sim: TestbenchContext, addr: int): + await self.send_req(sim, addr) + self.m.accept_res.enable(sim) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) def test_1_way(self): self.init_module(1, 4) - def cache_user_process(): + async def cache_user_process(sim: TestbenchContext): # The first request should cause a cache miss - yield from self.call_cache(0x00010004) + await self.call_cache(sim, 0x00010004) self.expect_refill(0x00010000) # Accesses to the same cache line shouldn't cause a cache miss for i in range(self.cp.fetch_blocks_in_line): - yield from self.call_cache(0x00010000 + i * self.cp.fetch_block_bytes) + await self.call_cache(sim, 0x00010000 + i * self.cp.fetch_block_bytes) assert len(self.refill_requests) == 0 # Now go beyond the first cache line - yield from self.call_cache(0x00010000 + self.cp.line_size_bytes) + await self.call_cache(sim, 0x00010000 + self.cp.line_size_bytes) self.expect_refill(0x00010000 + self.cp.line_size_bytes) # Trigger cache aliasing - yield from self.call_cache(0x00020000) - yield from self.call_cache(0x00010000) + await self.call_cache(sim, 0x00020000) + await self.call_cache(sim, 0x00010000) self.expect_refill(0x00020000) self.expect_refill(0x00010000) # Fill the whole cache for i in range(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4): - yield from self.call_cache(i) + await self.call_cache(sim, i) for i in range(self.cp.num_of_sets): self.expect_refill(i * self.cp.line_size_bytes) # Now do some accesses within the cached memory for i in range(50): - yield from self.call_cache(random.randrange(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4)) + await self.call_cache(sim, random.randrange(0, self.cp.line_size_bytes * self.cp.num_of_sets, 4)) assert len(self.refill_requests) == 0 with self.run_simulation(self.m) as sim: - sim.add_process(cache_user_process) + sim.add_testbench(cache_user_process) def test_2_way(self): self.init_module(2, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): # Fill the first set of both ways - yield from self.call_cache(0x00010000) - yield from self.call_cache(0x00020000) + await self.call_cache(sim, 0x00010000) + await self.call_cache(sim, 0x00020000) self.expect_refill(0x00010000) self.expect_refill(0x00020000) # And now both lines should be in the cache - yield from self.call_cache(0x00010004) - yield from self.call_cache(0x00020004) + await self.call_cache(sim, 0x00010004) + await self.call_cache(sim, 0x00020004) assert len(self.refill_requests) == 0 with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) # Tests whether the 
cache is fully pipelined and the latency between requests and response is exactly one cycle. def test_pipeline(self): self.init_module(2, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): # Fill the cache for i in range(self.cp.num_of_sets): addr = 0x00010000 + i * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) self.expect_refill(addr) - yield from self.tick(5) + await self.tick(sim, 4) # Create a stream of requests to ensure the pipeline is working - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) for i in range(0, self.cp.num_of_sets * self.cp.line_size_bytes, 4): addr = 0x00010000 + i self.issued_requests.append(addr) # Send the request - yield from self.m.issue_req.call_init(addr=addr) - yield Settle() - assert (yield from self.m.issue_req.done()) + ret = await self.m.issue_req.call_try(sim, addr=addr) + assert ret is not None # After a cycle the response should be ready - yield Tick() - yield from self.expect_resp() - yield from self.m.issue_req.disable() + await self.expect_resp(sim) - yield Tick() - yield from self.m.accept_res.disable() + self.m.accept_res.disable(sim) - yield from self.tick(5) + await self.tick(sim, 4) # Check how the cache handles queuing the requests - yield from self.send_req(addr=0x00010000 + 3 * self.cp.line_size_bytes) - yield from self.send_req(addr=0x00010004) + await self.send_req(sim, addr=0x00010000 + 3 * self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00010004) # Wait a few cycles. There are two requests queued - yield from self.tick(5) + await self.tick(sim, 4) - yield from self.m.accept_res.enable() - yield from self.expect_resp() - yield Tick() - yield from self.expect_resp() - yield from self.send_req(addr=0x0001000C) - yield from self.expect_resp() + self.m.accept_res.enable(sim) + await self.expect_resp( + sim, + ) + await self.expect_resp( + sim, + ) + await self.send_req(sim, addr=0x0001000C) + await self.expect_resp( + sim, + ) - yield Tick() - yield from self.m.accept_res.disable() + self.m.accept_res.disable(sim) - yield from self.tick(5) + await self.tick(sim, 4) # Schedule two requests, the first one causing a cache miss - yield from self.send_req(addr=0x00020000) - yield from self.send_req(addr=0x00010000 + self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00020000) + await self.send_req(sim, addr=0x00010000 + self.cp.line_size_bytes) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp() - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp( + sim, + ) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 2) # Schedule two requests, the second one causing a cache miss - yield from self.send_req(addr=0x00020004) - yield from self.send_req(addr=0x00030000 + self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00020004) + await self.send_req(sim, addr=0x00030000 + self.cp.line_size_bytes) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp() - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp( + sim, + ) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 2) # Schedule two requests, both causing a cache miss - yield from 
self.send_req(addr=0x00040000) - yield from self.send_req(addr=0x00050000 + self.cp.line_size_bytes) + await self.send_req(sim, addr=0x00040000) + await self.send_req(sim, addr=0x00050000 + self.cp.line_size_bytes) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) def test_flush(self): self.init_module(2, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): # Fill the whole cache for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): addr = w * 0x00010000 + s * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) self.expect_refill(addr) # Everything should be in the cache for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): addr = w * 0x00010000 + s * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) assert len(self.refill_requests) == 0 - yield from self.m.flush_cache.call() + await self.m.flush_cache.call(sim) # The cache should be empty for s in range(self.cp.num_of_sets): for w in range(self.cp.num_of_ways): addr = w * 0x00010000 + s * self.cp.line_size_bytes - yield from self.call_cache(addr) + await self.call_cache(sim, addr) self.expect_refill(addr) # Try to flush during refilling the line - yield from self.send_req(0x00030000) - yield from self.m.flush_cache.call() + await self.send_req(sim, 0x00030000) + await self.m.flush_cache.call(sim) # We still should be able to accept the response for the last request - self.assert_resp((yield from self.m.accept_res.call())) + self.assert_resp(await self.m.accept_res.call(sim)) self.expect_refill(0x00030000) - yield from self.call_cache(0x00010000) + await self.call_cache(sim, 0x00010000) self.expect_refill(0x00010000) - yield Tick() - # Try to execute issue_req and flush_cache methods at the same time - yield from self.m.issue_req.call_init(addr=0x00010000) self.issued_requests.append(0x00010000) - yield from self.m.flush_cache.call_init() - yield Settle() - assert not (yield from self.m.issue_req.done()) - assert (yield from self.m.flush_cache.done()) - yield Tick() - yield from self.m.flush_cache.call_do() - yield from self.m.issue_req.call_do() - self.assert_resp((yield from self.m.accept_res.call())) + issue_req_res, flush_cache_res = ( + await CallTrigger(sim).call(self.m.issue_req, addr=0x00010000).call(self.m.flush_cache) + ) + assert issue_req_res is None + assert flush_cache_res is not None + await self.m.issue_req.call(sim, addr=0x00010000) + self.assert_resp(await self.m.accept_res.call(sim)) self.expect_refill(0x00010000) - yield Tick() - # Schedule two requests and then flush - yield from self.send_req(0x00000000 + self.cp.line_size_bytes) - yield from self.send_req(0x00010000) + await self.send_req(sim, 0x00000000 + self.cp.line_size_bytes) + await self.send_req(sim, 0x00010000) - yield from self.m.flush_cache.call_init() - yield Tick() + res = await self.m.flush_cache.call_try(sim) # We cannot flush until there are two pending requests - assert not (yield from self.m.flush_cache.done()) - yield Tick() - yield from self.m.flush_cache.disable() - yield Tick() + assert res is None + res = 
await self.m.flush_cache.call_try(sim) + assert res is None # Accept the first response - self.assert_resp((yield from self.m.accept_res.call())) + self.assert_resp(await self.m.accept_res.call(sim)) - yield from self.m.flush_cache.call() + await self.m.flush_cache.call(sim) # And accept the second response ensuring that we got old data - self.assert_resp((yield from self.m.accept_res.call())) + self.assert_resp(await self.m.accept_res.call(sim)) self.expect_refill(0x00000000 + self.cp.line_size_bytes) # Just make sure that the line is truly flushed - yield from self.call_cache(0x00010000) + await self.call_cache(sim, 0x00010000) self.expect_refill(0x00010000) with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) def test_errors(self): self.init_module(1, 4) - def cache_process(): + async def cache_process(sim: TestbenchContext): self.add_bad_addr(0x00010000) # Bad addr at the beginning of the line self.add_bad_addr(0x00020008) # Bad addr in the middle of the line self.add_bad_addr( 0x00030000 + self.cp.line_size_bytes - self.cp.word_width_bytes ) # Bad addr at the end of the line - yield from self.call_cache(0x00010008) + await self.call_cache(sim, 0x00010008) self.expect_refill(0x00010000) # Requesting a bad addr again should retrigger refill - yield from self.call_cache(0x00010008) + await self.call_cache(sim, 0x00010008) self.expect_refill(0x00010000) - yield from self.call_cache(0x00020000) + await self.call_cache(sim, 0x00020000) self.expect_refill(0x00020000) - yield from self.call_cache(0x00030008) + await self.call_cache(sim, 0x00030008) self.expect_refill(0x00030000) # Test how pipelining works with errors - yield from self.m.accept_res.disable() - yield Tick() + self.m.accept_res.disable(sim) # Schedule two requests, the first one causing an error - yield from self.send_req(addr=0x00020000) - yield from self.send_req(addr=0x00011000) + await self.send_req(sim, addr=0x00020000) + await self.send_req(sim, addr=0x00011000) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 3) # Schedule two requests, the second one causing an error - yield from self.send_req(addr=0x00021004) - yield from self.send_req(addr=0x00030000) + await self.send_req(sim, addr=0x00021004) + await self.send_req(sim, addr=0x00030000) - yield from self.tick(10) + await self.tick(sim, 10) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() + await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) - yield from self.tick(3) + await self.tick(sim, 3) # Schedule two requests, both causing an error - yield from self.send_req(addr=0x00020000) - yield from self.send_req(addr=0x00010000) + await self.send_req(sim, addr=0x00020000) + await self.send_req(sim, addr=0x00010000) - yield from self.m.accept_res.enable() + self.m.accept_res.enable(sim) - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.expect_resp(wait=True) - yield Tick() - yield from self.m.accept_res.disable() - yield Tick()
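Two helpers carry most of the rewritten control flow in this file. `call_try` attempts a method call for a single cycle and resolves to `None` when the transaction could not run, replacing the old `call_init`/`done`/`disable` sequence; `CallTrigger` chains several calls so that they are attempted in the same cycle, resolving to one result (or `None`) per call. A sketch of both idioms as they are used above (`issue_req` and `flush_cache` stand in for any two TestbenchIO-wrapped methods):

    # Attempt a single call for one cycle; None means the method did not run.
    res = await self.m.flush_cache.call_try(sim)
    if res is None:
        pass  # blocked this cycle, e.g. by requests still in flight

    # Attempt two calls in the same cycle; each slot is None if that call failed.
    issue_res, flush_res = (
        await CallTrigger(sim).call(self.m.issue_req, addr=0x00010000).call(self.m.flush_cache)
    )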
+ await self.expect_resp(sim, wait=True) + await self.expect_resp(sim, wait=True) + self.m.accept_res.disable(sim) # The second request will cause an error - yield from self.send_req(addr=0x00021004) - yield from self.send_req(addr=0x00030000) + await self.send_req(sim, addr=0x00021004) + await self.send_req(sim, addr=0x00030000) - yield from self.tick(10) + await self.tick(sim, 10) # Accept the first response - yield from self.m.accept_res.enable() - yield from self.expect_resp(wait=True) - yield Tick() + self.m.accept_res.enable(sim) + await self.expect_resp(sim, wait=True) # Wait before accepting the second response - yield from self.m.accept_res.disable() - yield from self.tick(10) - yield from self.m.accept_res.enable() - yield from self.expect_resp(wait=True) - - yield Tick() + self.m.accept_res.disable(sim) + await self.tick(sim, 10) + self.m.accept_res.enable(sim) + await self.expect_resp(sim, wait=True) # This request should not cause an error - yield from self.send_req(addr=0x00011000) - yield from self.expect_resp(wait=True) + await self.send_req(sim, addr=0x00011000) + await self.expect_resp(sim, wait=True) with self.run_simulation(self.m) as sim: - sim.add_process(cache_process) + sim.add_testbench(cache_process) def test_random(self): self.init_module(4, 8) @@ -765,34 +741,28 @@ def test_random(self): if random.random() < 0.05: self.add_bad_addr(i) - def refiller_ctrl(): - yield Passive() - + async def refiller_ctrl(sim: TestbenchContext): while True: - yield from self.random_wait_geom(0.4) + await self.random_wait_geom(sim, 0.4) self.accept_refill_request = False - yield from self.random_wait_geom(0.7) + await self.random_wait_geom(sim, 0.7) self.accept_refill_request = True - def sender(): + async def sender(sim: TestbenchContext): for _ in range(iterations): - yield from self.send_req(random.randrange(0, max_addr, 4)) + await self.send_req(sim, random.randrange(0, max_addr, 4)) + await self.random_wait_geom(sim, 0.5) - while random.random() < 0.5: - yield Tick() - - def receiver(): + async def receiver(sim: TestbenchContext): for _ in range(iterations): while len(self.issued_requests) == 0: - yield Tick() - - self.assert_resp((yield from self.m.accept_res.call())) + await sim.tick() - while random.random() < 0.2: - yield Tick() + self.assert_resp(await self.m.accept_res.call(sim)) + await self.random_wait_geom(sim, 0.2) with self.run_simulation(self.m) as sim: - sim.add_process(sender) - sim.add_process(receiver) - sim.add_process(refiller_ctrl) + sim.add_testbench(sender) + sim.add_testbench(receiver) + sim.add_testbench(refiller_ctrl, background=True) diff --git a/test/core_structs/test_rat.py b/test/core_structs/test_rat.py index 01809d677..57093bb97 100644 --- a/test/core_structs/test_rat.py +++ b/test/core_structs/test_rat.py @@ -1,4 +1,4 @@ -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.core_structs.rat import FRAT, RRAT from coreblocks.params import GenParams @@ -7,6 +7,8 @@ from collections import deque from random import Random +from transactron.testing.testbenchio import CallTrigger + class TestFrontendRegisterAliasTable(TestCaseWithSimulator): def gen_input(self): @@ -18,14 +20,18 @@ def gen_input(self): self.to_execute_list.append({"rl": rl, "rp": rp, "rl_s1": rl_s1, "rl_s2": rl_s2}) - def do_rename(self): + async def do_rename(self, sim: TestbenchContext): for _ in range(self.test_steps): to_execute = self.to_execute_list.pop() 
- res = yield from self.m.rename.call( - rl_dst=to_execute["rl"], rp_dst=to_execute["rp"], rl_s1=to_execute["rl_s1"], rl_s2=to_execute["rl_s2"] + res = await self.m.rename.call( + sim, + rl_dst=to_execute["rl"], + rp_dst=to_execute["rp"], + rl_s1=to_execute["rl_s1"], + rl_s2=to_execute["rl_s2"], ) - assert res["rp_s1"] == self.expected_entries[to_execute["rl_s1"]] - assert res["rp_s2"] == self.expected_entries[to_execute["rl_s2"]] + assert res.rp_s1 == self.expected_entries[to_execute["rl_s1"]] + assert res.rp_s2 == self.expected_entries[to_execute["rl_s2"]] self.expected_entries[to_execute["rl"]] = to_execute["rp"] @@ -44,7 +50,7 @@ def test_single(self): self.gen_input() with self.run_simulation(m) as sim: - sim.add_process(self.do_rename) + sim.add_testbench(self.do_rename) class TestRetirementRegisterAliasTable(TestCaseWithSimulator): @@ -55,14 +61,17 @@ def gen_input(self): self.to_execute_list.append({"rl": rl, "rp": rp}) - def do_commit(self): + async def do_commit(self, sim: TestbenchContext): for _ in range(self.test_steps): to_execute = self.to_execute_list.pop() - yield from self.m.peek.call_init(rl_dst=to_execute["rl"]) - res = yield from self.m.commit.call(rl_dst=to_execute["rl"], rp_dst=to_execute["rp"]) - peek_res = yield from self.m.peek.call_do() - assert res["old_rp_dst"] == self.expected_entries[to_execute["rl"]] - assert peek_res["old_rp_dst"] == res["old_rp_dst"] + peek_res, res = ( + await CallTrigger(sim) + .call(self.m.peek, rl_dst=to_execute["rl"]) + .call(self.m.commit, rl_dst=to_execute["rl"], rp_dst=to_execute["rp"]) + ) + assert peek_res is not None and res is not None + assert res.old_rp_dst == self.expected_entries[to_execute["rl"]] + assert peek_res.old_rp_dst == res["old_rp_dst"] self.expected_entries[to_execute["rl"]] = to_execute["rp"] @@ -81,4 +90,4 @@ def test_single(self): self.gen_input() with self.run_simulation(m) as sim: - sim.add_process(self.do_commit) + sim.add_testbench(self.do_commit) diff --git a/test/core_structs/test_reorder_buffer.py b/test/core_structs/test_reorder_buffer.py index 0589e7db1..b1f935a81 100644 --- a/test/core_structs/test_reorder_buffer.py +++ b/test/core_structs/test_reorder_buffer.py @@ -1,6 +1,4 @@ -from amaranth.sim import Passive, Settle, Tick - -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.core_structs.rob import ReorderBuffer from coreblocks.params import GenParams @@ -9,54 +7,50 @@ from queue import Queue from random import Random +from transactron.testing.functions import data_const_to_dict + class TestReorderBuffer(TestCaseWithSimulator): - def gen_input(self): + async def gen_input(self, sim: TestbenchContext): for _ in range(self.test_steps): while self.regs_left_queue.empty(): - yield Tick() + await sim.tick() - while self.rand.random() < 0.5: - yield # to slow down puts + await self.random_wait_geom(sim, 0.5) # to slow down puts log_reg = self.rand.randint(0, self.log_regs - 1) phys_reg = self.regs_left_queue.get() regs = {"rl_dst": log_reg, "rp_dst": phys_reg} - rob_id = yield from self.m.put.call(regs) + rob_id = (await self.m.put.call(sim, regs)).rob_id self.to_execute_list.append((rob_id, phys_reg)) - self.retire_queue.put((regs, rob_id["rob_id"])) + self.retire_queue.put((regs, rob_id)) - def do_updates(self): - yield Passive() + async def do_updates(self, sim: TestbenchContext): while True: - while self.rand.random() < 0.5: - yield # to slow down execution + await 
self.random_wait_geom(sim, 0.5) # to slow down execution if len(self.to_execute_list) == 0: - yield Tick() + await sim.tick() else: idx = self.rand.randint(0, len(self.to_execute_list) - 1) rob_id, executed = self.to_execute_list.pop(idx) self.executed_list.append(executed) - yield from self.m.mark_done.call(rob_id) + await self.m.mark_done.call(sim, rob_id=rob_id, exception=0) - def do_retire(self): + async def do_retire(self, sim: TestbenchContext): cnt = 0 while True: if self.retire_queue.empty(): - self.m.retire.enable() - yield Tick() - is_ready = yield self.m.retire.adapter.done - assert is_ready == 0 # transaction should not be ready if there is nothing to retire + res = await self.m.retire.call_try(sim) + assert res is None # transaction should not be ready if there is nothing to retire else: regs, rob_id_exp = self.retire_queue.get() - results = yield from self.m.peek.call() - yield from self.m.retire.call() - phys_reg = results["rob_data"]["rp_dst"] - assert rob_id_exp == results["rob_id"] + results = await self.m.peek.call(sim) + await self.m.retire.call(sim) + phys_reg = results.rob_data.rp_dst + assert rob_id_exp == results.rob_id assert phys_reg in self.executed_list self.executed_list.remove(phys_reg) - yield Settle() - assert results["rob_data"] == regs + assert data_const_to_dict(results.rob_data) == regs self.regs_left_queue.put(phys_reg) cnt += 1 @@ -82,40 +76,38 @@ def test_single(self): self.log_regs = self.gen_params.isa.reg_cnt with self.run_simulation(m) as sim: - sim.add_process(self.gen_input) - sim.add_process(self.do_updates) - sim.add_process(self.do_retire) + sim.add_testbench(self.gen_input) + sim.add_testbench(self.do_updates, background=True) + sim.add_testbench(self.do_retire) class TestFullDoneCase(TestCaseWithSimulator): - def gen_input(self): + async def gen_input(self, sim: TestbenchContext): for _ in range(self.test_steps): log_reg = self.rand.randrange(self.log_regs) phys_reg = self.rand.randrange(self.phys_regs) - rob_id = yield from self.m.put.call(rl_dst=log_reg, rp_dst=phys_reg) + rob_id = (await self.m.put.call(sim, rl_dst=log_reg, rp_dst=phys_reg)).rob_id self.to_execute_list.append(rob_id) - def do_single_update(self): + async def do_single_update(self, sim: TestbenchContext): while len(self.to_execute_list) == 0: - yield Tick() + await sim.tick() rob_id = self.to_execute_list.pop(0) - yield from self.m.mark_done.call(rob_id) + await self.m.mark_done.call(sim, rob_id=rob_id) - def do_retire(self): + async def do_retire(self, sim: TestbenchContext): for i in range(self.test_steps - 1): - yield from self.do_single_update() + await self.do_single_update(sim) - yield from self.m.retire.call() - yield from self.do_single_update() + await self.m.retire.call(sim) + await self.do_single_update(sim) for i in range(self.test_steps - 1): - yield from self.m.retire.call() + await self.m.retire.call(sim) - yield from self.m.retire.enable() - yield Tick() - res = yield self.m.retire.adapter.done - assert res == 0 # should be disabled, since we have read all elements + res = await self.m.retire.call_try(sim) + assert res is None # since we have read all elements def test_single(self): self.rand = Random(0) @@ -130,5 +122,5 @@ def test_single(self): self.phys_regs = 2**self.gen_params.phys_regs_bits with self.run_simulation(m) as sim: - sim.add_process(self.gen_input) - sim.add_process(self.do_retire) + sim.add_testbench(self.gen_input) + sim.add_testbench(self.do_retire) diff --git a/test/frontend/test_decode_stage.py b/test/frontend/test_decode_stage.py 
index 8cfcb95fd..acab29abd 100644 --- a/test/frontend/test_decode_stage.py +++ b/test/frontend/test_decode_stage.py @@ -1,7 +1,7 @@ import pytest from transactron.lib import AdapterTrans, FIFO - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, SimpleTestCircuit, ModuleConnector +from transactron.utils.amaranth_ext.elaboratables import ModuleConnector +from transactron.testing import TestCaseWithSimulator, TestbenchIO, SimpleTestCircuit, TestbenchContext from coreblocks.frontend.decoder.decode_stage import DecodeStage from coreblocks.params import GenParams @@ -32,10 +32,10 @@ def setup(self, fixture_initialize_testing_env): ) ) - def decode_test_proc(self): + async def decode_test_proc(self, sim: TestbenchContext): # testing an OP_IMM instruction (test copied from test_decoder.py) - yield from self.fifo_in_write.call(instr=0x02A28213) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x02A28213) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.ARITHMETIC assert decoded["exec_fn"]["funct3"] == Funct3.ADD @@ -46,8 +46,8 @@ def decode_test_proc(self): assert decoded["imm"] == 42 # testing an OP instruction (test copied from test_decoder.py) - yield from self.fifo_in_write.call(instr=0x003100B3) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x003100B3) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.ARITHMETIC assert decoded["exec_fn"]["funct3"] == Funct3.ADD @@ -57,8 +57,8 @@ def decode_test_proc(self): assert decoded["regs_l"]["rl_s2"] == 3 # testing an illegal - yield from self.fifo_in_write.call(instr=0x0) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x0) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.EXCEPTION assert decoded["exec_fn"]["funct3"] == Funct3._EILLEGALINSTR @@ -67,8 +67,8 @@ def decode_test_proc(self): assert decoded["regs_l"]["rl_s1"] == 0 assert decoded["regs_l"]["rl_s2"] == 0 - yield from self.fifo_in_write.call(instr=0x0, access_fault=1) - decoded = yield from self.fifo_out_read.call() + await self.fifo_in_write.call(sim, instr=0x0, access_fault=1) + decoded = await self.fifo_out_read.call(sim) assert decoded["exec_fn"]["op_type"] == OpType.EXCEPTION assert decoded["exec_fn"]["funct3"] == Funct3._EINSTRACCESSFAULT @@ -79,4 +79,4 @@ def decode_test_proc(self): def test(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.decode_test_proc) + sim.add_testbench(self.decode_test_proc) diff --git a/test/frontend/test_fetch.py b/test/frontend/test_fetch.py index 33b216752..5e2406776 100644 --- a/test/frontend/test_fetch.py +++ b/test/frontend/test_fetch.py @@ -6,13 +6,20 @@ import random from amaranth import Elaboratable, Module -from amaranth.sim import Passive, Tick from coreblocks.interface.keys import FetchResumeKey from transactron.core import Method from transactron.lib import AdapterTrans, Adapter, BasicFifo +from transactron.testing.method_mock import MethodMock from transactron.utils import ModuleConnector -from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, SimpleTestCircuit, TestGen +from transactron.testing import ( + TestCaseWithSimulator, + TestbenchIO, + def_method_mock, + SimpleTestCircuit, + TestbenchContext, + ProcessContext, +) from coreblocks.frontend.fetch.fetch import FetchUnit, PredictionChecker from 
coreblocks.cache.iface import CacheInterface @@ -133,15 +140,12 @@ def gen_branch(self, offset: int, taken: bool): return self.add_instr(data, True, jump_offset=offset, branch_taken=taken) - def cache_process(self): - yield Passive() - + async def cache_process(self, sim: ProcessContext): while True: while len(self.input_q) == 0: - yield Tick() + await sim.tick() - while random.random() < 0.5: - yield Tick() + await self.random_wait_geom(sim, 0.5) req_addr = self.input_q.popleft() & ~(self.gen_params.fetch_block_bytes - 1) @@ -162,15 +166,24 @@ def load_or_gen_mem(addr): self.output_q.append({"fetch_block": fetch_block, "error": bad_addr}) - @def_method_mock(lambda self: self.icache.issue_req_io, enable=lambda self: len(self.input_q) < 2, sched_prio=1) + @def_method_mock( + lambda self: self.icache.issue_req_io, enable=lambda self: len(self.input_q) < 2 + ) # TODO had sched_prio def issue_req_mock(self, addr): - self.input_q.append(addr) + @MethodMock.effect + def eff(): + self.input_q.append(addr) @def_method_mock(lambda self: self.icache.accept_res_io, enable=lambda self: len(self.output_q) > 0) def accept_res_mock(self): - return self.output_q.popleft() + @MethodMock.effect + def eff(): + self.output_q.popleft() - def fetch_out_check(self): + if self.output_q: + return self.output_q[0] + + async def fetch_out_check(self, sim: TestbenchContext): while self.instr_queue: instr = self.instr_queue.popleft() @@ -178,7 +191,7 @@ def fetch_out_check(self): if not instr["rvc"]: access_fault |= instr["pc"] + 2 in self.memerr - v = yield from self.io_out.call() + v = await self.io_out.call(sim) assert v["pc"] == instr["pc"] assert v["access_fault"] == access_fault @@ -188,13 +201,13 @@ def fetch_out_check(self): assert v["instr"] == instr_data if (instr["jumps"] and (instr["branch_taken"] != v["predicted_taken"])) or access_fault: - yield from self.random_wait(5) - yield from self.fetch.stall_exception.call() - yield from self.random_wait(5) + await self.random_wait(sim, 5) + await self.fetch.stall_exception.call(sim) + await self.random_wait(sim, 5) # Empty the pipeline - yield from self.clean_fifo.call_try() - yield Tick() + await self.clean_fifo.call_try(sim) + await sim.tick() resume_pc = instr["next_pc"] if access_fault: @@ -204,13 +217,13 @@ def fetch_out_check(self): ) + self.gen_params.fetch_block_bytes # Resume the fetch unit - while (yield from self.fetch.resume_from_exception.call_try(pc=resume_pc)) is None: + while await self.fetch.resume_from_exception.call_try(sim, pc=resume_pc) is None: pass def run_sim(self): with self.run_simulation(self.m) as sim: sim.add_process(self.cache_process) - sim.add_process(self.fetch_out_check) + sim.add_testbench(self.fetch_out_check) def test_simple_no_jumps(self): for _ in range(50): @@ -390,7 +403,7 @@ def test_random(self): with self.run_simulation(self.m) as sim: sim.add_process(self.cache_process) - sim.add_process(self.fetch_out_check) + sim.add_testbench(self.fetch_out_check) @dataclass(frozen=True) @@ -424,8 +437,9 @@ def setup(self, fixture_initialize_testing_env): self.m = SimpleTestCircuit(PredictionChecker(self.gen_params)) - def check( + async def check( self, + sim: TestbenchContext, pc: int, block_cross: bool, predecoded: list[tuple[CfiType, int]], @@ -434,7 +448,7 @@ def check( cfi_type: CfiType, cfi_target: Optional[int], valid_mask: int = -1, - ) -> TestGen[CheckerResult]: + ) -> CheckerResult: # Fill the array with non-CFI instructions for _ in range(self.gen_params.fetch_width - len(predecoded)): 
predecoded.append((CfiType.INVALID, 0)) @@ -457,7 +471,8 @@ def check( instr_valid = (((1 << self.gen_params.fetch_width) - 1) << instr_start) & valid_mask - res = yield from self.m.check.call( + res = await self.m.check.call( + sim, fb_addr=pc >> self.gen_params.fetch_block_bytes_log, instr_block_cross=block_cross, instr_valid=instr_valid, @@ -493,46 +508,46 @@ def test_no_misprediction(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width - def proc(): + async def proc(sim: TestbenchContext): # No CFI at all - ret = yield from self.check(0x100, False, [], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, False, [], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=False) # There is one forward branch that we didn't predict - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=False) # There are many forward branches that we didn't predict - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, 100)] * fetch_width, 0, 0, CfiType.INVALID, None + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, 100)] * fetch_width, 0, 0, CfiType.INVALID, None ) self.assert_resp(ret, mispredicted=False) # There is a predicted JAL instr - ret = yield from self.check(0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 0x100 + 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 0x100 + 100) self.assert_resp(ret, mispredicted=False) # There is a predicted JALR instr - the predecoded offset can now be anything - ret = yield from self.check(0x100, False, [(CfiType.JALR, 200)], 0, 0, CfiType.JALR, 0x100 + 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JALR, 200)], 0, 0, CfiType.JALR, 0x100 + 100) self.assert_resp(ret, mispredicted=False) # There is a forward taken-predicted branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 0x100 + 100) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 0x100 + 100) self.assert_resp(ret, mispredicted=False) # There is a backward taken-predicted branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100) self.assert_resp(ret, mispredicted=False) # Branch located between two fetch blocks if self.with_rvc: - ret = yield from self.check( - 0x100, True, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100 - 2 + ret = await self.check( + sim, 0x100, True, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100 - 2 ) self.assert_resp(ret, mispredicted=False) # One branch predicted as not taken - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.INVALID, 0) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.INVALID, 0) self.assert_resp(ret, mispredicted=False) # Now tests for fetch blocks with multiple instructions @@ -540,7 +555,8 @@ def proc(): return # Predicted taken branch as the second instruction - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], @@ -552,13 +568,14 @@ def proc(): self.assert_resp(ret, mispredicted=False) # Predicted, but not taken branch 
as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b10, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b10, 0, CfiType.INVALID, 0 ) self.assert_resp(ret, mispredicted=False) if self.with_rvc: - ret = yield from self.check( + ret = await self.check( + sim, 0x100, True, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], @@ -569,7 +586,8 @@ def proc(): ) self.assert_resp(ret, mispredicted=False) - ret = yield from self.check( + ret = await self.check( + sim, 0x100, True, [(CfiType.JAL, 100), (CfiType.JAL, -100)], @@ -582,15 +600,16 @@ def proc(): self.assert_resp(ret, mispredicted=False) # Two branches with all possible combinations taken/not-taken - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.INVALID, 0 ) self.assert_resp(ret, mispredicted=False) - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.BRANCH, 0x100 - 100 + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.BRANCH, 0x100 - 100 ) self.assert_resp(ret, mispredicted=False) - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], @@ -602,17 +621,25 @@ def proc(): self.assert_resp(ret, mispredicted=False) # JAL at the beginning, but we start from the second instruction - ret = yield from self.check(0x100 + instr_width, False, [(CfiType.JAL, -100)], 0b0, 0, CfiType.INVALID, 0) + ret = await self.check(sim, 0x100 + instr_width, False, [(CfiType.JAL, -100)], 0b0, 0, CfiType.INVALID, 0) self.assert_resp(ret, mispredicted=False) # JAL and a forward branch that we didn't predict - ret = yield from self.check( - 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, + 0x100 + instr_width, + False, + [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], + 0b00, + 0, + CfiType.INVALID, + 0, ) self.assert_resp(ret, mispredicted=False) # two JAL instructions, but we start from the second one - ret = yield from self.check( + ret = await self.check( + sim, 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.JAL, 100)], @@ -624,7 +651,8 @@ def proc(): self.assert_resp(ret, mispredicted=False) # JAL and a branch, but we start from the second instruction - ret = yield from self.check( + ret = await self.check( + sim, 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], @@ -636,24 +664,24 @@ def proc(): self.assert_resp(ret, mispredicted=False) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) def test_preceding_redirection(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width - def proc(): + async def proc(sim: TestbenchContext): # No prediction was made, but there is a JAL at the beginning - ret = yield from self.check(0x100, False, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20) # The same, but the jump is between two fetch blocks if self.with_rvc: - ret = 
yield from self.check(0x100, True, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) + ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20 - 2) # Not predicted backward branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0b0, 0, CfiType.INVALID, 0) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b0, 0, CfiType.INVALID, 0) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # Now tests for fetch blocks with multiple instructions @@ -661,7 +689,8 @@ def proc(): return # We predicted the branch on the second instruction, but there's a JAL on the first one. - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.JAL, -100), (CfiType.BRANCH, 100)], @@ -673,7 +702,8 @@ def proc(): self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # We predicted the branch on the second instruction, but there's a JALR on the first one. - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.JALR, -100), (CfiType.BRANCH, 100)], @@ -685,7 +715,8 @@ def proc(): self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0) # We predicted the branch on the second instruction, but there's a backward branch on the first one. - ret = yield from self.check( + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], @@ -697,31 +728,32 @@ def proc(): self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # Unpredicted backward branch as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b00, 0, CfiType.INVALID, 0 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width - 100 ) # Unpredicted JAL as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b00, 0, CfiType.INVALID, 0 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100 ) # Unpredicted JALR as the second instruction - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.JALR, 100)], 0b00, 0, CfiType.INVALID, 0 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JALR, 100)], 0b00, 0, CfiType.INVALID, 0 ) self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=1) if fetch_width < 3: return - ret = yield from self.check( + ret = await self.check( + sim, 0x100 + instr_width, False, [(CfiType.JAL, -100), (CfiType.INVALID, 100), (CfiType.JAL, 100)], @@ -735,94 +767,101 @@ def proc(): ) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) def test_mispredicted_cfi_type(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width fb_bytes = self.gen_params.fetch_block_bytes - def proc(): + async def proc(sim: TestbenchContext): # We predicted a JAL, but in fact there is a non-CFI instruction - ret = yield from self.check(0x100, False, 
[(CfiType.INVALID, 0)], 0, 0, CfiType.JAL, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.INVALID, 0)], 0, 0, CfiType.JAL, 100) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes ) # We predicted a JAL, but in fact there is a branch - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, -100)], 0, 0, CfiType.JAL, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0, 0, CfiType.JAL, 100) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) # We predicted a JAL, but in fact there is a JALR instruction - ret = yield from self.check(0x100, False, [(CfiType.JALR, -100)], 0, 0, CfiType.JAL, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JALR, -100)], 0, 0, CfiType.JAL, 100) self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0) # We predicted a branch, but in fact there is a JAL - ret = yield from self.check(0x100, False, [(CfiType.JAL, -100)], 0b1, 0, CfiType.BRANCH, 100) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, -100)], 0b1, 0, CfiType.BRANCH, 100) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100) if fetch_width < 2: return # There is a branch and a non-CFI, but we predicted two branches - ret = yield from self.check( - 0x100, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100 + ret = await self.check( + sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes ) # The same as above, but we start from the second instruction - ret = yield from self.check( - 0x100 + instr_width, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100 + ret = await self.check( + sim, + 0x100 + instr_width, + False, + [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], + 0b11, + 1, + CfiType.BRANCH, + 100, ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes ) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) def test_mispredicted_cfi_target(self): instr_width = self.gen_params.min_instr_width_bytes fetch_width = self.gen_params.fetch_width - def proc(): + async def proc(sim: TestbenchContext): # We predicted a wrong JAL target - ret = yield from self.check(0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 200) + ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 200) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100) # We predicted a wrong branch target - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 200) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 200) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100) # We didn't provide the branch target - ret = yield from self.check(0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, None) + ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, None) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100) # We predicted a wrong JAL target that is between two fetch blocks if self.with_rvc: - ret = yield 
from self.check(0x100, True, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 300) + ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 300) self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100 - 2) if fetch_width < 2: return # The second instruction is a branch without the target - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, 100)], 0b10, 1, CfiType.BRANCH, None + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, 100)], 0b10, 1, CfiType.BRANCH, None ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100 ) # The second instruction is a JAL with a wrong target - ret = yield from self.check( - 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b10, 1, CfiType.JAL, 200 + ret = await self.check( + sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b10, 1, CfiType.JAL, 200 ) self.assert_resp( ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100 ) with self.run_simulation(self.m) as sim: - sim.add_process(proc) + sim.add_testbench(proc) diff --git a/test/frontend/test_instr_decoder.py b/test/frontend/test_instr_decoder.py index 09de8ca91..f7e82125c 100644 --- a/test/frontend/test_instr_decoder.py +++ b/test/frontend/test_instr_decoder.py @@ -1,6 +1,6 @@ from amaranth.sim import * -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext from coreblocks.params import * from coreblocks.params.configurations import test_core_config @@ -194,64 +194,63 @@ def setup_method(self): self.cnt = 1 def do_test(self, tests: list[InstrTest]): - def process(): + async def process(sim: TestbenchContext): for test in tests: - yield self.decoder.instr.eq(test.encoding) - yield Settle() + sim.set(self.decoder.instr, test.encoding) - assert (yield self.decoder.illegal) == test.illegal + assert sim.get(self.decoder.illegal) == test.illegal if test.illegal: return - assert (yield self.decoder.opcode) == test.opcode + assert sim.get(self.decoder.opcode) == test.opcode if test.funct3 is not None: - assert (yield self.decoder.funct3) == test.funct3 - assert (yield self.decoder.funct3_v) == (test.funct3 is not None) + assert sim.get(self.decoder.funct3) == test.funct3 + assert sim.get(self.decoder.funct3_v) == (test.funct3 is not None) if test.funct7 is not None: - assert (yield self.decoder.funct7) == test.funct7 - assert (yield self.decoder.funct7_v) == (test.funct7 is not None) + assert sim.get(self.decoder.funct7) == test.funct7 + assert sim.get(self.decoder.funct7_v) == (test.funct7 is not None) if test.funct12 is not None: - assert (yield self.decoder.funct12) == test.funct12 - assert (yield self.decoder.funct12_v) == (test.funct12 is not None) + assert sim.get(self.decoder.funct12) == test.funct12 + assert sim.get(self.decoder.funct12_v) == (test.funct12 is not None) if test.rd is not None: - assert (yield self.decoder.rd) == test.rd - assert (yield self.decoder.rd_v) == (test.rd is not None) + assert sim.get(self.decoder.rd) == test.rd + assert sim.get(self.decoder.rd_v) == (test.rd is not None) if test.rs1 is not None: - assert (yield self.decoder.rs1) == test.rs1 - assert (yield self.decoder.rs1_v) == (test.rs1 is not None) + assert sim.get(self.decoder.rs1) == test.rs1 + assert sim.get(self.decoder.rs1_v) == (test.rs1 is not None) if test.rs2 is not None: - assert (yield 
self.decoder.rs2) == test.rs2 - assert (yield self.decoder.rs2_v) == (test.rs2 is not None) + assert sim.get(self.decoder.rs2) == test.rs2 + assert sim.get(self.decoder.rs2_v) == (test.rs2 is not None) if test.imm is not None: if test.csr is not None: # in CSR instructions, additional fields are passed in the unused bits of the imm field - assert (yield self.decoder.imm.as_signed() & ((2**5) - 1)) == test.imm + assert sim.get(self.decoder.imm.as_signed() & ((2**5) - 1)) == test.imm else: - assert (yield self.decoder.imm.as_signed()) == test.imm + assert sim.get(self.decoder.imm.as_signed()) == test.imm if test.succ is not None: - assert (yield self.decoder.succ) == test.succ + assert sim.get(self.decoder.succ) == test.succ if test.pred is not None: - assert (yield self.decoder.pred) == test.pred + assert sim.get(self.decoder.pred) == test.pred if test.fm is not None: - assert (yield self.decoder.fm) == test.fm + assert sim.get(self.decoder.fm) == test.fm if test.csr is not None: - assert (yield self.decoder.csr) == test.csr + assert sim.get(self.decoder.csr) == test.csr - assert (yield self.decoder.optype) == test.op + assert sim.get(self.decoder.optype) == test.op with self.run_simulation(self.decoder) as sim: - sim.add_process(process) + sim.add_testbench(process) def test_i(self): self.do_test(self.DECODER_TESTS_I) @@ -296,14 +295,13 @@ def test_e(self): self.gen_params = GenParams(test_core_config.replace(embedded=True, _implied_extensions=Extension.E)) self.decoder = InstrDecoder(self.gen_params) - def process(): + async def process(sim: TestbenchContext): for encoding, illegal in self.E_TEST: - yield self.decoder.instr.eq(encoding) - yield Settle() - assert (yield self.decoder.illegal) == illegal + sim.set(self.decoder.instr, encoding) + assert sim.get(self.decoder.illegal) == illegal with self.run_simulation(self.decoder) as sim: - sim.add_process(process) + sim.add_testbench(process) class TestEncodingUniqueness(TestCase): diff --git a/test/frontend/test_rvc.py b/test/frontend/test_rvc.py index f1690f8dd..53dcaebd0 100644 --- a/test/frontend/test_rvc.py +++ b/test/frontend/test_rvc.py @@ -1,6 +1,5 @@ from parameterized import parameterized_class -from amaranth.sim import Settle, Tick from amaranth import * from coreblocks.frontend.decoder.rvc import InstrDecompress @@ -9,7 +8,7 @@ from coreblocks.params.configurations import test_core_config from transactron.utils import ValueLike -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, TestbenchContext COMMON_TESTS = [ # Illegal instruction @@ -283,22 +282,18 @@ def test(self): ) self.m = InstrDecompress(self.gen_params) - def process(): - illegal = Signal(32) - yield illegal.eq(IllegalInstr()) + async def process(sim: TestbenchContext): + illegal = Const.cast(IllegalInstr()).value for instr_in, instr_out in self.test_cases: - yield self.m.instr_in.eq(instr_in) - expected = Signal(32) - yield expected.eq(instr_out) - yield Settle() + sim.set(self.m.instr_in, instr_in) + expected = Const.cast(instr_out).value - if (yield expected) == (yield illegal): - yield expected.eq(instr_in) # for exception handling - yield Settle() + if expected == illegal: + expected = instr_in # for exception handling - assert (yield self.m.instr_out) == (yield expected) - yield Tick() + assert sim.get(self.m.instr_out) == expected + await sim.tick() with self.run_simulation(self.m) as sim: - sim.add_process(process) + sim.add_testbench(process) diff --git a/test/func_blocks/csr/test_csr.py
b/test/func_blocks/csr/test_csr.py index 340afeda9..6fa8c95e7 100644 --- a/test/func_blocks/csr/test_csr.py +++ b/test/func_blocks/csr/test_csr.py @@ -1,5 +1,5 @@ from amaranth import * -from random import random +import random from transactron.lib import Adapter from transactron.core.tmodule import TModule @@ -17,13 +17,14 @@ CSRInstancesKey, ) from coreblocks.arch.isa_consts import PrivilegeLevel +from transactron.lib.adapters import AdapterTrans from transactron.utils.dependencies import DependencyContext from transactron.testing import * class CSRUnitTestCircuit(Elaboratable): - def __init__(self, gen_params, csr_count, only_legal=True): + def __init__(self, gen_params: GenParams, csr_count: int, only_legal=True): self.gen_params = gen_params self.csr_count = csr_count self.only_legal = only_legal @@ -32,7 +33,12 @@ def elaborate(self, platform): m = Module() m.submodules.precommit = self.precommit = TestbenchIO( - Adapter(o=self.gen_params.get(RetirementLayouts).precommit, nonexclusive=True) + Adapter( + i=self.gen_params.get(RetirementLayouts).precommit_in, + o=self.gen_params.get(RetirementLayouts).precommit_out, + nonexclusive=True, + combiner=lambda m, args, runs: args[0], + ).set(with_validate_arguments=True) ) DependencyContext.get().add_dependency(InstructionPrecommitKey(), self.precommit.adapter.iface) @@ -72,8 +78,8 @@ def make_csr(number: int): class TestCSRUnit(TestCaseWithSimulator): - def gen_expected_out(self, op, rd, rs1, operand_val, csr): - exp_read = {"rp_dst": rd, "result": (yield self.dut.csr[csr].value)} + def gen_expected_out(self, sim: TestbenchContext, op: Funct3, rd: int, rs1: int, operand_val: int, csr: int): + exp_read = {"rp_dst": rd, "result": sim.get(self.dut.csr[csr].value)} rs1_val = {"rp_s1": rs1, "value": operand_val} exp_write = {} @@ -84,11 +90,11 @@ def gen_expected_out(self, op, rd, rs1, operand_val, csr): elif (op == Funct3.CSRRS and rs1) or op == Funct3.CSRRSI: exp_write = {"csr": csr, "value": exp_read["result"] | operand_val} else: - exp_write = {"csr": csr, "value": (yield self.dut.csr[csr].value)} + exp_write = {"csr": csr, "value": sim.get(self.dut.csr[csr].value)} return {"exp_read": exp_read, "exp_write": exp_write, "rs1": rs1_val} - def generate_instruction(self): + def generate_instruction(self, sim: TestbenchContext): ops = [ Funct3.CSRRW, Funct3.CSRRC, @@ -108,7 +114,7 @@ def generate_instruction(self): operand_val = imm if imm_op else rs1_val csr = random.choice(list(self.dut.csr.keys())) - exp = yield from self.gen_expected_out(op, rd, rs1, operand_val, csr) + exp = self.gen_expected_out(sim, op, rd, rs1, operand_val, csr) value_available = random.random() < 0.2 @@ -125,34 +131,38 @@ def generate_instruction(self): "exp": exp, } - def process_test(self): - yield from self.dut.fetch_resume.enable() - yield from self.dut.exception_report.enable() + async def process_test(self, sim: TestbenchContext): + self.dut.fetch_resume.enable(sim) + self.dut.exception_report.enable(sim) for _ in range(self.cycles): - yield from self.random_wait_geom() + await self.random_wait_geom(sim) - op = yield from self.generate_instruction() + op = self.generate_instruction(sim) - yield from self.dut.select.call() + await self.dut.select.call(sim) - yield from self.dut.insert.call(rs_data=op["instr"]) + await self.dut.insert.call(sim, rs_data=op["instr"]) - yield from self.random_wait_geom() + await self.random_wait_geom(sim) if op["exp"]["rs1"]["rp_s1"]: - yield from self.dut.update.call(reg_id=op["exp"]["rs1"]["rp_s1"], reg_val=op["exp"]["rs1"]["value"]) 
+ await self.dut.update.call(sim, reg_id=op["exp"]["rs1"]["rp_s1"], reg_val=op["exp"]["rs1"]["value"]) - yield from self.random_wait_geom() - yield from self.dut.precommit.call(side_fx=1) + await self.random_wait_geom(sim) + # TODO: this is a hack, a real method mock should be used + for _, r in self.dut.precommit.adapter.validators: # type: ignore + sim.set(r, 1) + self.dut.precommit.call_init(sim, side_fx=1) # TODO: sensible precommit handling - yield from self.random_wait_geom() - res = yield from self.dut.accept.call() + await self.random_wait_geom(sim) + res, resume_res = await CallTrigger(sim).call(self.dut.accept).sample(self.dut.fetch_resume).until_done() + self.dut.precommit.disable(sim) - assert self.dut.fetch_resume.done() - assert res["rp_dst"] == op["exp"]["exp_read"]["rp_dst"] + assert res is not None and resume_res is not None + assert res.rp_dst == op["exp"]["exp_read"]["rp_dst"] if op["exp"]["exp_read"]["rp_dst"]: - assert res["result"] == op["exp"]["exp_read"]["result"] - assert (yield self.dut.csr[op["exp"]["exp_write"]["csr"]].value) == op["exp"]["exp_write"]["value"] - assert res["exception"] == 0 + assert res.result == op["exp"]["exp_read"]["result"] + assert sim.get(self.dut.csr[op["exp"]["exp_write"]["csr"]].value) == op["exp"]["exp_write"]["value"] + assert res.exception == 0 def test_randomized(self): self.gen_params = GenParams(test_core_config) @@ -164,7 +174,7 @@ def test_randomized(self): self.dut = CSRUnitTestCircuit(self.gen_params, self.csr_count) with self.run_simulation(self.dut) as sim: - sim.add_process(self.process_test) + sim.add_testbench(self.process_test) exception_csr_numbers = [ 0xCC0, # read_only @@ -172,21 +182,22 @@ def test_randomized(self): 0x7FE, # missing priv ] - def process_exception_test(self): - yield from self.dut.fetch_resume.enable() - yield from self.dut.exception_report.enable() + async def process_exception_test(self, sim: TestbenchContext): + self.dut.fetch_resume.enable(sim) + self.dut.exception_report.enable(sim) for csr in self.exception_csr_numbers: if csr == 0x7FE: - yield from self.dut.priv_io.call(data=PrivilegeLevel.USER) + await self.dut.priv_io.call(sim, data=PrivilegeLevel.USER) else: - yield from self.dut.priv_io.call(data=PrivilegeLevel.MACHINE) + await self.dut.priv_io.call(sim, data=PrivilegeLevel.MACHINE) - yield from self.random_wait_geom() + await self.random_wait_geom(sim) - yield from self.dut.select.call() + await self.dut.select.call(sim) rob_id = random.randrange(2**self.gen_params.rob_entries_bits) - yield from self.dut.insert.call( + await self.dut.insert.call( + sim, rs_data={ "exec_fn": {"op_type": OpType.CSR_REG, "funct3": Funct3.CSRRW, "funct7": 0}, "rp_s1": 0, @@ -196,20 +207,24 @@ def process_exception_test(self): "imm": 0, "csr": csr, "rob_id": rob_id, - } + }, ) - yield from self.random_wait_geom() - yield from self.dut.precommit.call(rob_id=rob_id, side_fx=1) + await self.random_wait_geom(sim) + # TODO: this is a hack, a real method mock should be used + for _, r in self.dut.precommit.adapter.validators: # type: ignore + sim.set(r, 1) + self.dut.precommit.call_init(sim, side_fx=1) - yield from self.random_wait_geom() - res = yield from self.dut.accept.call() + await self.random_wait_geom(sim) + res, report = await CallTrigger(sim).call(self.dut.accept).sample(self.dut.exception_report).until_done() + self.dut.precommit.disable(sim) assert res["exception"] == 1 - report = yield from self.dut.exception_report.call_result() - assert isinstance(report, dict) - report.pop("mtval") # mtval tested in 
mtval.asm test - assert {"rob_id": rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report + assert report is not None + report_dict = data_const_to_dict(report) + report_dict.pop("mtval") # mtval tested in mtval.asm test + assert {"rob_id": rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report_dict def test_exception(self): self.gen_params = GenParams(test_core_config) @@ -218,13 +233,13 @@ def test_exception(self): self.dut = CSRUnitTestCircuit(self.gen_params, 0, only_legal=False) with self.run_simulation(self.dut) as sim: - sim.add_process(self.process_exception_test) + sim.add_testbench(self.process_exception_test) class TestCSRRegister(TestCaseWithSimulator): - def randomized_process_test(self): + async def randomized_process_test(self, sim: TestbenchContext): # always enabled - yield from self.dut.read.enable() + self.dut.read.enable(sim) previous_data = 0 for _ in range(self.cycles): @@ -236,7 +251,7 @@ def randomized_process_test(self): if random.random() < 0.9: write = True exp_write_data = random.randint(0, 2**self.gen_params.isa.xlen - 1) - yield from self.dut.write.call_init(data=exp_write_data) + self.dut.write.call_init(sim, data=exp_write_data) if random.random() < 0.3: fu_write = True @@ -245,33 +260,32 @@ def randomized_process_test(self): exp_write_data = (write_arg & ~self.ro_mask) | ( (exp_write_data if exp_write_data is not None else previous_data) & self.ro_mask ) - yield from self.dut._fu_write.call_init(data=write_arg) + self.dut._fu_write.call_init(sim, data=write_arg) if random.random() < 0.2: fu_read = True - yield from self.dut._fu_read.enable() + self.dut._fu_read.call_init(sim) - yield Tick() - yield Settle() + await sim.tick() exp_read_data = exp_write_data if fu_write or write else previous_data if fu_read: # in CSRUnit this call is called before write and returns previous result - assert (yield from self.dut._fu_read.call_result()) == {"data": exp_read_data} + assert data_const_to_dict(self.dut._fu_read.get_call_result(sim)) == {"data": exp_read_data} - assert (yield from self.dut.read.call_result()) == { + assert data_const_to_dict(self.dut.read.get_call_result(sim)) == { "data": exp_read_data, "read": int(fu_read), "written": int(fu_write), } - read_result = yield from self.dut.read.call_result() + read_result = self.dut.read.get_call_result(sim) assert read_result is not None - previous_data = read_result["data"] + previous_data = read_result.data - yield from self.dut._fu_read.disable() - yield from self.dut._fu_write.disable() - yield from self.dut.write.disable() + self.dut._fu_read.disable(sim) + self.dut._fu_write.disable(sim) + self.dut.write.disable(sim) def test_randomized(self): self.gen_params = GenParams(test_core_config) @@ -283,15 +297,15 @@ def test_randomized(self): self.dut = SimpleTestCircuit(CSRRegister(0, self.gen_params, ro_bits=self.ro_mask)) with self.run_simulation(self.dut) as sim: - sim.add_process(self.randomized_process_test) + sim.add_testbench(self.randomized_process_test) - def filtermap_process_test(self): + async def filtermap_process_test(self, sim: TestbenchContext): prev_value = 0 for _ in range(50): input = random.randrange(0, 2**34) - yield from self.dut._fu_write.call({"data": input}) - output = (yield from self.dut._fu_read.call())["data"] + await self.dut._fu_write.call(sim, data=input) + output = (await self.dut._fu_read.call(sim))["data"] expected = prev_value if input & 1: @@ -331,43 +345,46 @@ def write_filtermap(m: TModule, v: Value): ro_bits=(1 << 32), fu_read_map=lambda _, 
v: v << 1, fu_write_filtermap=write_filtermap, - ) + ), ) with self.run_simulation(self.dut) as sim: - sim.add_process(self.filtermap_process_test) - - def comb_process_test(self): - yield from self.dut.read.enable() - yield from self.dut.read_comb.enable() - yield from self.dut._fu_read.enable() - - yield from self.dut._fu_write.call_init({"data": 0xFFFF}) - yield from self.dut._fu_write.call_do() - assert (yield from self.dut.read_comb.call_result())["data"] == 0xFFFF - assert (yield from self.dut._fu_read.call_result())["data"] == 0xAB - yield Tick() - assert (yield from self.dut.read.call_result())["data"] == 0xFFFB - assert (yield from self.dut._fu_read.call_result())["data"] == 0xFFFB - yield Tick() - - yield from self.dut._fu_write.call_init({"data": 0x0FFF}) - yield from self.dut.write.call_init({"data": 0xAAAA}) - yield from self.dut._fu_write.call_do() - yield from self.dut.write.call_do() - assert (yield from self.dut.read_comb.call_result()) == {"data": 0x0FFF, "read": 1, "written": 1} - yield Tick() - assert (yield from self.dut._fu_read.call_result())["data"] == 0xAAAA - yield Tick() + sim.add_testbench(self.filtermap_process_test) + + async def comb_process_test(self, sim: TestbenchContext): + self.dut.read.enable(sim) + self.dut.read_comb.enable(sim) + self.dut._fu_read.enable(sim) + + self.dut._fu_write.call_init(sim, data=0xFFFF) + while self.dut._fu_write.get_call_result(sim) is None: + await sim.tick() + assert self.dut.read_comb.get_call_result(sim).data == 0xFFFF + assert self.dut._fu_read.get_call_result(sim).data == 0xAB + await sim.tick() + assert self.dut.read.get_call_result(sim)["data"] == 0xFFFB + assert self.dut._fu_read.get_call_result(sim)["data"] == 0xFFFB + await sim.tick() + + self.dut._fu_write.call_init(sim, data=0x0FFF) + self.dut.write.call_init(sim, data=0xAAAA) + while self.dut._fu_write.get_call_result(sim) is None or self.dut.write.get_call_result(sim) is None: + await sim.tick() + assert data_const_to_dict(self.dut.read_comb.get_call_result(sim)) == {"data": 0x0FFF, "read": 1, "written": 1} + await sim.tick() + assert self.dut._fu_read.get_call_result(sim).data == 0xAAAA + await sim.tick() # single cycle - yield from self.dut._fu_write.call_init({"data": 0x0BBB}) - yield from self.dut._fu_write.call_do() - update_val = (yield from self.dut.read_comb.call_result())["data"] | 0xD000 - yield from self.dut.write.call_init({"data": update_val}) - yield from self.dut.write.call_do() - yield Tick() - assert (yield from self.dut._fu_read.call_result())["data"] == 0xDBBB + self.dut._fu_write.call_init(sim, data=0x0BBB) + while self.dut._fu_write.get_call_result(sim) is None: + await sim.tick() + update_val = self.dut.read_comb.get_call_result(sim).data | 0xD000 + self.dut.write.call_init(sim, data=update_val) + while self.dut.write.get_call_result(sim) is None: + await sim.tick() + await sim.tick() + assert self.dut._fu_read.get_call_result(sim).data == 0xDBBB def test_comb(self): gen_params = GenParams(test_core_config) @@ -377,4 +394,4 @@ def test_comb(self): self.dut = SimpleTestCircuit(CSRRegister(None, gen_params, ro_bits=0b1111, fu_write_priority=False, init=0xAB)) with self.run_simulation(self.dut) as sim: - sim.add_process(self.comb_process_test) + sim.add_testbench(self.comb_process_test) diff --git a/test/func_blocks/fu/common/test_rs.py b/test/func_blocks/fu/common/test_rs.py index 222041a2a..7d311dede 100644 --- a/test/func_blocks/fu/common/test_rs.py +++ b/test/func_blocks/fu/common/test_rs.py @@ -2,15 +2,14 @@ from collections import deque 
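A recurring idiom in the test_rs.py rewrite below is replacing yield Settle() with a sub-cycle await sim.delay(1e-9), which yields control so that a cooperating testbench scheduled in the same cycle can act first. A minimal sketch of the pattern, assuming only the TestbenchContext API already used in this patch (tick, delay); queue, producer and consumer are illustrative names, not part of the patch:

    from collections import deque

    from transactron.testing import TestbenchContext

    queue = deque()

    async def producer(sim: TestbenchContext):
        for item in range(4):
            queue.appendleft(item)  # hand the item to the consumer
            await sim.tick()        # then advance one clock cycle

    async def consumer(sim: TestbenchContext):
        for _ in range(4):
            await sim.delay(1e-9)   # let the producer act first within this cycle
            assert queue            # the producer has already pushed its item
            queue.pop()
            await sim.tick()

Both coroutines would be registered with sim.add_testbench, as in the tests below.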
from parameterized import parameterized_class -from amaranth.sim import Settle, Tick - -from transactron.testing import TestCaseWithSimulator, get_outputs, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.func_blocks.fu.common.rs import RS, RSBase from coreblocks.func_blocks.fu.common.fifo_rs import FifoRS from coreblocks.params import * from coreblocks.params.configurations import test_core_config from coreblocks.arch import OpType +from transactron.testing.functions import data_const_to_dict def create_check_list(rs_entries_bits: int, insert_list: list[dict]) -> list[dict]: @@ -35,7 +34,7 @@ def create_data_list(gen_params: GenParams, count: int): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": k, "s2_val": k, @@ -75,35 +74,35 @@ def test_rs(self): self.finished = False with self.run_simulation(self.m) as sim: - sim.add_process(self.select_process) - sim.add_process(self.insert_process) - sim.add_process(self.update_process) - sim.add_process(self.take_process) + sim.add_testbench(self.select_process) + sim.add_testbench(self.insert_process) + sim.add_testbench(self.update_process) + sim.add_testbench(self.take_process) - def select_process(self): + async def select_process(self, sim: TestbenchContext): for k in range(len(self.data_list)): - rs_entry_id = (yield from self.m.select.call())["rs_entry_id"] + rs_entry_id = (await self.m.select.call(sim)).rs_entry_id self.select_queue.appendleft(rs_entry_id) self.rs_entries[rs_entry_id] = k - def insert_process(self): + async def insert_process(self, sim: TestbenchContext): for data in self.data_list: - yield Settle() # so that select_process can insert into the queue + await sim.delay(1e-9) # so that select_process can insert into the queue while not self.select_queue: - yield Tick() - yield Settle() + await sim.tick() + await sim.delay(1e-9) rs_entry_id = self.select_queue.pop() - yield from self.m.insert.call({"rs_entry_id": rs_entry_id, "rs_data": data}) + await self.m.insert.call(sim, rs_entry_id=rs_entry_id, rs_data=data) if data["rp_s1"]: self.regs_to_update.add(data["rp_s1"]) if data["rp_s2"]: self.regs_to_update.add(data["rp_s2"]) - def update_process(self): + async def update_process(self, sim: TestbenchContext): while not self.finished: - yield Settle() # so that insert_process can insert into the set + await sim.delay(1e-9) # so that insert_process can insert into the set if not self.regs_to_update: - yield Tick() + await sim.tick() continue reg_id = random.choice(list(self.regs_to_update)) self.regs_to_update.discard(reg_id) @@ -115,29 +114,26 @@ def update_process(self): if self.data_list[k]["rp_s2"] == reg_id: self.data_list[k]["rp_s2"] = 0 self.data_list[k]["s2_val"] = reg_val - yield from self.m.update.call(reg_id=reg_id, reg_val=reg_val) + await self.m.update.call(sim, reg_id=reg_id, reg_val=reg_val) - def take_process(self): + async def take_process(self, sim: TestbenchContext): taken: set[int] = set() - yield from self.m.get_ready_list[0].call_init() - yield Settle() + self.m.get_ready_list[0].call_init(sim) for k in range(len(self.data_list)): - yield Settle() - while not (yield from self.m.get_ready_list[0].done()): - yield Tick() - ready_list = (yield from self.m.get_ready_list[0].call_result())["ready_list"] + while not self.m.get_ready_list[0].get_done(sim): + await sim.tick() + ready_list = (self.m.get_ready_list[0].get_call_result(sim)).ready_list possible_ids = [i for i in range(2**self.rs_entries_bits) 
if ready_list & (1 << i)] - if not possible_ids: - yield Tick() - continue + while not possible_ids: + await sim.tick() rs_entry_id = random.choice(possible_ids) k = self.rs_entries[rs_entry_id] taken.add(k) test_data = dict(self.data_list[k]) del test_data["rp_s1"] del test_data["rp_s2"] - data = yield from self.m.take.call(rs_entry_id=rs_entry_id) - assert data == test_data + data = await self.m.take.call(sim, rs_entry_id=rs_entry_id) + assert data_const_to_dict(data) == test_data assert taken == set(range(len(self.data_list))) self.finished = True @@ -158,7 +154,7 @@ def test_insert(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -171,20 +167,18 @@ def test_insert(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for index, record in enumerate(self.insert_list): - assert (yield self.m._dut.data[index].rec_full) == 0 - yield from self.m.insert.call(record) - yield Settle() - assert (yield self.m._dut.data[index].rec_full) == 1 - yield Settle() + assert sim.get(self.m._dut.data[index].rec_full) == 0 + await self.m.insert.call(sim, record) + assert sim.get(self.m._dut.data[index].rec_full) == 1 # Check data integrity for expected, record in zip(self.check_list, self.m._dut.data): - assert expected == (yield from get_outputs(record)) + assert expected == data_const_to_dict(sim.get(record)) class TestRSMethodSelect(TestCaseWithSimulator): @@ -203,7 +197,7 @@ def test_select(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -216,38 +210,33 @@ def test_select(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # In the beginning the select method should be ready and id should be selectable for index, record in enumerate(self.insert_list): - assert (yield self.m._dut.select.ready) == 1 - assert (yield from self.m.select.call())["rs_entry_id"] == index - yield Settle() - assert (yield self.m._dut.data[index].rec_reserved) == 1 - yield from self.m.insert.call(record) - yield Settle() + assert sim.get(self.m._dut.select.ready) == 1 + assert (await self.m.select.call(sim)).rs_entry_id == index + assert sim.get(self.m._dut.data[index].rec_reserved) == 1 + await self.m.insert.call(sim, record) # Check if RS state is as expected for expected, record in zip(self.check_list, self.m._dut.data): - assert (yield record.rec_full) == expected["rec_full"] - assert (yield record.rec_reserved) == expected["rec_reserved"] + assert sim.get(record.rec_full) == expected["rec_full"] + assert sim.get(record.rec_reserved) == expected["rec_reserved"] # Reserve the last entry, then select ready should be false - assert (yield self.m._dut.select.ready) == 1 - assert (yield from self.m.select.call())["rs_entry_id"] == 3 - yield Settle() - assert (yield self.m._dut.select.ready) == 0 + assert sim.get(self.m._dut.select.ready) == 1 + assert (await self.m.select.call(sim)).rs_entry_id == 3 + assert sim.get(self.m._dut.select.ready) == 0 # After take, 
select ready should be true, with 0 index returned - yield from self.m.take.call(rs_entry_id=0) - yield Settle() - assert (yield self.m._dut.select.ready) == 1 - assert (yield from self.m.select.call())["rs_entry_id"] == 0 + await self.m.take.call(sim, rs_entry_id=0) + assert sim.get(self.m._dut.select.ready) == 1 + assert (await self.m.select.call(sim)).rs_entry_id == 0 # After reservation, select is false again - yield Settle() - assert (yield self.m._dut.select.ready) == 0 + assert sim.get(self.m._dut.select.ready) == 0 class TestRSMethodUpdate(TestCaseWithSimulator): @@ -266,7 +255,7 @@ def test_update(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -279,34 +268,31 @@ def test_update(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # Insert all records for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) # Check data integrity for expected, record in zip(self.check_list, self.m._dut.data): - assert expected == (yield from get_outputs(record)) + assert expected == data_const_to_dict(sim.get(record)) # Update second entry first SP, instruction should be not ready value_sp1 = 1010 - assert (yield self.m._dut.data_ready[1]) == 0 - yield from self.m.update.call(reg_id=2, reg_val=value_sp1) - yield Settle() - assert (yield self.m._dut.data[1].rs_data.rp_s1) == 0 - assert (yield self.m._dut.data[1].rs_data.s1_val) == value_sp1 - assert (yield self.m._dut.data_ready[1]) == 0 + assert sim.get(self.m._dut.data_ready[1]) == 0 + await self.m.update.call(sim, reg_id=2, reg_val=value_sp1) + assert sim.get(self.m._dut.data[1].rs_data.rp_s1) == 0 + assert sim.get(self.m._dut.data[1].rs_data.s1_val) == value_sp1 + assert sim.get(self.m._dut.data_ready[1]) == 0 # Update second entry second SP, instruction should be ready value_sp2 = 2020 - yield from self.m.update.call(reg_id=3, reg_val=value_sp2) - yield Settle() - assert (yield self.m._dut.data[1].rs_data.rp_s2) == 0 - assert (yield self.m._dut.data[1].rs_data.s2_val) == value_sp2 - assert (yield self.m._dut.data_ready[1]) == 1 + await self.m.update.call(sim, reg_id=3, reg_val=value_sp2) + assert sim.get(self.m._dut.data[1].rs_data.rp_s2) == 0 + assert sim.get(self.m._dut.data[1].rs_data.s2_val) == value_sp2 + assert sim.get(self.m._dut.data_ready[1]) == 1 # Insert new instruction to entries 0 and 1, check if update of multiple registers works reg_id = 4 @@ -319,7 +305,7 @@ def simulation_process(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": 0, "s2_val": 0, } for index in range(2): - yield from self.m.insert.call(rs_entry_id=index, rs_data=data) - yield Settle() - assert (yield self.m._dut.data_ready[index]) == 0 + await self.m.insert.call(sim, rs_entry_id=index, rs_data=data) + assert sim.get(self.m._dut.data_ready[index]) == 0 - yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx) - yield Settle() + await self.m.update.call(sim, reg_id=reg_id, reg_val=value_spx) for index in range(2): - assert (yield self.m._dut.data[index].rs_data.rp_s1) == 0 - assert (yield self.m._dut.data[index].rs_data.rp_s2) == 0 - assert (yield
self.m._dut.data[index].rs_data.s1_val) == value_spx - assert (yield self.m._dut.data[index].rs_data.s2_val) == value_spx - assert (yield self.m._dut.data_ready[index]) == 1 + assert sim.get(self.m._dut.data[index].rs_data.rp_s1) == 0 + assert sim.get(self.m._dut.data[index].rs_data.rp_s2) == 0 + assert sim.get(self.m._dut.data[index].rs_data.s1_val) == value_spx + assert sim.get(self.m._dut.data[index].rs_data.s2_val) == value_spx + assert sim.get(self.m._dut.data_ready[index]) == 1 class TestRSMethodTake(TestCaseWithSimulator): @@ -357,7 +341,7 @@ def test_take(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -370,37 +354,33 @@ def test_take(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) # Check data integrity for expected, record in zip(self.check_list, self.m._dut.data): - assert expected == (yield from get_outputs(record)) + assert expected == data_const_to_dict(sim.get(record)) # Take first instruction - assert (yield self.m._dut.get_ready_list[0].ready) == 1 - data = yield from self.m.take.call(rs_entry_id=0) + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=0)) for key in data: assert data[key] == self.check_list[0]["rs_data"][key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 0 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 0 # Update second instruction and take it reg_id = 2 value_spx = 1 - yield from self.m.update.call(reg_id=reg_id, reg_val=value_spx) - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 1 - data = yield from self.m.take.call(rs_entry_id=1) + await self.m.update.call(sim, reg_id=reg_id, reg_val=value_spx) + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=1)) for key in data: assert data[key] == self.check_list[1]["rs_data"][key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 0 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 0 # Insert two new ready instructions and take them reg_id = 0 @@ -413,7 +393,7 @@ def simulation_process(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": 0, "s2_val": 0, @@ -422,22 +402,19 @@ for index in range(2): - yield from self.m.insert.call(rs_entry_id=index, rs_data=entry_data) - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 1 - assert (yield self.m._dut.data_ready[index]) == 1 + await self.m.insert.call(sim, rs_entry_id=index, rs_data=entry_data) + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 + assert sim.get(self.m._dut.data_ready[index]) == 1 - data = yield from self.m.take.call(rs_entry_id=0) + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=0)) for key in data: assert data[key] == entry_data[key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 1 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 1 - data = yield from
self.m.take.call(rs_entry_id=1) + data = data_const_to_dict(await self.m.take.call(sim, rs_entry_id=1)) for key in data: assert data[key] == entry_data[key] - yield Settle() - assert (yield self.m._dut.get_ready_list[0].ready) == 0 + assert sim.get(self.m._dut.get_ready_list[0].ready) == 0 class TestRSMethodGetReadyList(TestCaseWithSimulator): @@ -456,7 +433,7 @@ def test_get_ready_list(self): "exec_fn": { "op_type": 1, "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -469,28 +446,25 @@ def test_get_ready_list(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) # Check ready vector integrity - ready_list = (yield from self.m.get_ready_list[0].call())["ready_list"] + ready_list = (await self.m.get_ready_list[0].call(sim)).ready_list assert ready_list == 0b0011 # Take first record and check ready vector integrity - yield from self.m.take.call(rs_entry_id=0) - yield Settle() - ready_list = (yield from self.m.get_ready_list[0].call())["ready_list"] + await self.m.take.call(sim, rs_entry_id=0) + ready_list = (await self.m.get_ready_list[0].call(sim)).ready_list assert ready_list == 0b0010 # Take second record and check ready vector integrity - yield from self.m.take.call(rs_entry_id=1) - yield Settle() - option_ready_list = yield from self.m.get_ready_list[0].call_try() + await self.m.take.call(sim, rs_entry_id=1) + option_ready_list = await self.m.get_ready_list[0].call_try(sim) assert option_ready_list is None @@ -500,7 +474,7 @@ def test_two_get_ready_lists(self): self.rs_entries = self.gen_params.max_rs_entries self.rs_entries_bits = self.gen_params.max_rs_entries_bits self.m = SimpleTestCircuit( - RS(self.gen_params, 2**self.rs_entries_bits, 0, [[OpType(1), OpType(2)], [OpType(3), OpType(4)]]) + RS(self.gen_params, 2**self.rs_entries_bits, 0, [[OpType(1), OpType(2)], [OpType(3), OpType(4)]]), ) self.insert_list = [ { @@ -513,7 +487,7 @@ def test_two_get_ready_lists(self): "exec_fn": { "op_type": OpType(id + 1), "funct3": 2, - "funct7": 3, + "funct7": 4, }, "s1_val": id, "s2_val": id, @@ -525,29 +499,27 @@ def test_two_get_ready_lists(self): self.check_list = create_check_list(self.rs_entries_bits, self.insert_list) with self.run_simulation(self.m) as sim: - sim.add_process(self.simulation_process) + sim.add_testbench(self.simulation_process) - def simulation_process(self): + async def simulation_process(self, sim: TestbenchContext): # After each insert, entry should be marked as full for record in self.insert_list: - yield from self.m.insert.call(record) - yield Settle() + await self.m.insert.call(sim, record) masks = [0b0011, 0b1100] for i in range(self.m._dut.rs_entries + 1): # Check ready vectors' integrity for j in range(2): - ready_list = yield from self.m.get_ready_list[j].call_try() + ready_list = await self.m.get_ready_list[j].call_try(sim) if masks[j]: - assert ready_list == {"ready_list": masks[j]} + assert ready_list.ready_list == masks[j] else: assert ready_list is None # Take a record if i == self.m._dut.rs_entries: break - yield from self.m.take.call(rs_entry_id=i) - yield Settle() + await self.m.take.call(sim, 
rs_entry_id=i) masks = [mask & ~(1 << i) for mask in masks] diff --git a/test/func_blocks/fu/fpu/test_fpu_error.py b/test/func_blocks/fu/fpu/test_fpu_error.py new file mode 100644 index 000000000..451ce5242 --- /dev/null +++ b/test/func_blocks/fu/fpu/test_fpu_error.py @@ -0,0 +1,305 @@ +from coreblocks.func_blocks.fu.fpu.fpu_error_module import * +from coreblocks.func_blocks.fu.fpu.fpu_common import ( + RoundingModes, + FPUParams, + Errors, +) +from transactron import TModule +from transactron.lib import AdapterTrans +from parameterized import parameterized +from transactron.testing import * +from amaranth import * + + +class TestFPUError(TestCaseWithSimulator): + class FPUErrorModule(Elaboratable): + def __init__(self, params: FPUParams): + self.params = params + + def elaborate(self, platform): + m = TModule() + m.submodules.fpue = fpue = self.fpu_error_module = FPUErrorModule(fpu_params=self.params) + m.submodules.error_checking = self.error_checking_request_adapter = TestbenchIO( + AdapterTrans(fpue.error_checking_request) + ) + return m + + class HelpValues: + def __init__(self, params: FPUParams): + self.params = params + self.max_exp = (2**self.params.exp_width) - 1 + self.max_norm_exp = (2**self.params.exp_width) - 2 + self.not_max_norm_exp = (2**self.params.exp_width) - 3 + self.max_sig = (2**params.sig_width) - 1 + self.not_max_norm_sig = 1 << (self.params.sig_width - 1) | 1 + self.not_max_norm_even_sig = 1 << (self.params.sig_width - 1) + self.sub_norm_sig = 3 + self.min_norm_sig = 1 << (self.params.sig_width - 1) + self.max_sub_norm_sig = (2 ** (self.params.sig_width - 1)) - 1 + self.qnan = 3 << (self.params.sig_width - 2) | 1 + + params = FPUParams(sig_width=24, exp_width=8) + help_values = HelpValues(params) + + @parameterized.expand([(params, help_values)]) + def test_special_cases(self, params: FPUParams, help_values: HelpValues): + fpue = TestFPUError.FPUErrorModule(params) + + async def other_cases_test(sim: TestbenchContext): + test_cases = [ + # No errors + { + "sign": 0, + "sig": help_values.not_max_norm_even_sig, + "exp": help_values.not_max_norm_exp, + "inexact": 0, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + # inexact + { + "sign": 0, + "sig": help_values.not_max_norm_even_sig, + "exp": help_values.not_max_norm_exp, + "inexact": 1, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + # underflow + { + "sign": 0, + "sig": help_values.sub_norm_sig, + "exp": 0, + "inexact": 1, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + # invalid operation + { + "sign": 0, + "sig": help_values.qnan, + "exp": help_values.max_exp, + "inexact": 1, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 1, + "division_by_zero": 0, + "input_inf": 0, + }, + # division by zero + { + "sign": 0, + "sig": 0, + "exp": help_values.max_exp, + "inexact": 1, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 1, + "input_inf": 0, + }, + # overflow but no round and sticky bits + { + "sign": 0, + "sig": 0, + "exp": help_values.max_exp, + "inexact": 0, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + # tininess but no underflow + { + "sign": 0, + "sig": help_values.sub_norm_sig, + "exp": 0, + "inexact": 0, + 
"rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + # one of inputs was qnan + { + "sign": 0, + "sig": help_values.qnan, + "exp": help_values.max_exp, + "inexact": 1, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + # one of inputs was inf + { + "sign": 1, + "sig": 0, + "exp": help_values.max_exp, + "inexact": 1, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 1, + }, + # subnormal number become normalized after rounding + { + "sign": 1, + "sig": help_values.min_norm_sig, + "exp": 0, + "inexact": 1, + "rounding_mode": RoundingModes.ROUND_NEAREST_AWAY, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + ] + + expected_results = [ + # No errors + {"sign": 0, "sig": help_values.not_max_norm_even_sig, "exp": help_values.not_max_norm_exp, "errors": 0}, + # inexact + { + "sign": 0, + "sig": help_values.not_max_norm_even_sig, + "exp": help_values.not_max_norm_exp, + "errors": Errors.INEXACT, + }, + # underflow + {"sign": 0, "sig": help_values.sub_norm_sig, "exp": 0, "errors": Errors.UNDERFLOW | Errors.INEXACT}, + # invalid operation + {"sign": 0, "sig": help_values.qnan, "exp": help_values.max_exp, "errors": Errors.INVALID_OPERATION}, + # division by zero + {"sign": 0, "sig": 0, "exp": help_values.max_exp, "errors": Errors.DIVISION_BY_ZERO}, + # overflow but no round and sticky bits + {"sign": 0, "sig": 0, "exp": help_values.max_exp, "errors": Errors.INEXACT | Errors.OVERFLOW}, + # tininess but no underflow + {"sign": 0, "sig": help_values.sub_norm_sig, "exp": 0, "errors": 0}, + # one of inputs was qnan + {"sign": 0, "sig": help_values.qnan, "exp": help_values.max_exp, "errors": 0}, + # one of inputs was inf + {"sign": 1, "sig": 0, "exp": help_values.max_exp, "errors": 0}, + # subnormal number become normalized after rounding + {"sign": 1, "sig": help_values.min_norm_sig, "exp": 1, "errors": Errors.INEXACT}, + ] + for i in range(len(test_cases)): + + resp = await fpue.error_checking_request_adapter.call(sim, test_cases[i]) + assert resp.sign == expected_results[i]["sign"] + assert resp.exp == expected_results[i]["exp"] + assert resp.sig == expected_results[i]["sig"] + assert resp.errors == expected_results[i]["errors"] + + async def test_process(sim: TestbenchContext): + await other_cases_test(sim) + + with self.run_simulation(fpue) as sim: + sim.add_testbench(test_process) + + @parameterized.expand( + [ + ( + params, + help_values, + RoundingModes.ROUND_NEAREST_EVEN, + 0, + help_values.max_exp, + 0, + help_values.max_exp, + ), + ( + params, + help_values, + RoundingModes.ROUND_NEAREST_AWAY, + 0, + help_values.max_exp, + 0, + help_values.max_exp, + ), + ( + params, + help_values, + RoundingModes.ROUND_UP, + 0, + help_values.max_exp, + help_values.max_sig, + help_values.max_norm_exp, + ), + ( + params, + help_values, + RoundingModes.ROUND_DOWN, + help_values.max_sig, + help_values.max_norm_exp, + 0, + help_values.max_exp, + ), + ( + params, + help_values, + RoundingModes.ROUND_ZERO, + help_values.max_sig, + help_values.max_norm_exp, + help_values.max_sig, + help_values.max_norm_exp, + ), + ] + ) + def test_rounding( + self, + params: FPUParams, + help_values: HelpValues, + rm: RoundingModes, + plus_overflow_sig: int, + plus_overflow_exp: int, + minus_overflow_sig: int, + minus_overflow_exp: int, + ): + fpue = TestFPUError.FPUErrorModule(params) + + 
async def one_rounding_mode_test(sim: TestbenchContext): + test_cases = [ + # overflow detection + { + "sign": 0, + "sig": 0, + "exp": help_values.max_exp, + "rounding_mode": rm, + "inexact": 0, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + { + "sign": 1, + "sig": 0, + "exp": help_values.max_exp, + "rounding_mode": rm, + "inexact": 0, + "invalid_operation": 0, + "division_by_zero": 0, + "input_inf": 0, + }, + ] + expected_results = [ + # overflow detection + {"sign": 0, "sig": plus_overflow_sig, "exp": plus_overflow_exp, "errors": 20}, + {"sign": 1, "sig": minus_overflow_sig, "exp": minus_overflow_exp, "errors": 20}, + ] + + for i in range(len(test_cases)): + resp = await fpue.error_checking_request_adapter.call(sim, test_cases[i]) + assert resp["sign"] == expected_results[i]["sign"] + assert resp["exp"] == expected_results[i]["exp"] + assert resp["sig"] == expected_results[i]["sig"] + assert resp["errors"] == expected_results[i]["errors"] + + async def test_process(sim: TestbenchContext): + await one_rounding_mode_test(sim) + + with self.run_simulation(fpue) as sim: + sim.add_testbench(test_process) diff --git a/test/func_blocks/fu/fpu/test_fpu_rounding.py b/test/func_blocks/fu/fpu/test_fpu_rounding.py new file mode 100644 index 000000000..d33849c06 --- /dev/null +++ b/test/func_blocks/fu/fpu/test_fpu_rounding.py @@ -0,0 +1,276 @@ +from coreblocks.func_blocks.fu.fpu.fpu_rounding_module import * +from coreblocks.func_blocks.fu.fpu.fpu_common import ( + RoundingModes, + FPUParams, +) +from transactron import TModule +from transactron.lib import AdapterTrans +from parameterized import parameterized +from transactron.testing import * +from amaranth import * + + +class TestFPURounding(TestCaseWithSimulator): + class FPURoundingModule(Elaboratable): + def __init__(self, params: FPUParams): + self.params = params + + def elaborate(self, platform): + m = TModule() + m.submodules.fpur = fpur = self.fpu_rounding = FPURounding(fpu_params=self.params) + m.submodules.rounding = self.rounding_request_adapter = TestbenchIO(AdapterTrans(fpur.rounding_request)) + return m + + class HelpValues: + def __init__(self, params: FPUParams): + self.params = params + self.max_exp = (2**self.params.exp_width) - 1 + self.max_norm_exp = (2**self.params.exp_width) - 2 + self.not_max_norm_exp = (2**self.params.exp_width) - 3 + self.max_sig = (2**params.sig_width) - 1 + self.not_max_norm_sig = 1 << (self.params.sig_width - 1) | 1 + self.not_max_norm_even_sig = 1 << (self.params.sig_width - 1) + self.sub_norm_sig = 3 + self.max_sub_norm_sig = (2 ** (self.params.sig_width - 1)) - 1 + self.qnan = 3 << (self.params.sig_width - 2) | 1 + + params = FPUParams(sig_width=24, exp_width=8) + help_values = HelpValues(params) + + tie_to_even_inc_array = [ + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + ] + tie_to_away_inc_array = [0, 1, 0, 1, 0, 1, 0, 1] + round_up_inc_array = [0, 1, 1, 1, 0, 0, 0, 0] + round_down_inc_array = [0, 0, 0, 0, 0, 1, 1, 1] + round_zero_inc_array = [0, 0, 0, 0, 0, 0, 0, 0] + + @parameterized.expand( + [ + ( + params, + help_values, + RoundingModes.ROUND_NEAREST_EVEN, + tie_to_away_inc_array, + ), + ( + params, + help_values, + RoundingModes.ROUND_NEAREST_AWAY, + tie_to_away_inc_array, + ), + ( + params, + help_values, + RoundingModes.ROUND_UP, + round_up_inc_array, + ), + ( + params, + help_values, + RoundingModes.ROUND_DOWN, + round_down_inc_array, + ), + ( + params, + help_values, + RoundingModes.ROUND_ZERO, + round_zero_inc_array, + ), + ] + ) + def test_rounding( + self, + 
params: FPUParams, + help_values: HelpValues, + rm: RoundingModes, + inc_arr: list, + ): + fpurt = TestFPURounding.FPURoundingModule(params) + + async def one_rounding_mode_test(sim: TestbenchContext): + test_cases = [ + # carry after increment + { + "sign": 0 if rm != RoundingModes.ROUND_DOWN else 1, + "sig": help_values.max_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 1, + "sticky_bit": 1, + "rounding_mode": rm, + }, + # no overflow 00 + { + "sign": 0, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 0, + "sticky_bit": 0, + "rounding_mode": rm, + }, + { + "sign": 1, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 0, + "sticky_bit": 0, + "rounding_mode": rm, + }, + # no overflow 10 + { + "sign": 0, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 1, + "sticky_bit": 0, + "rounding_mode": rm, + }, + { + "sign": 1, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 1, + "sticky_bit": 0, + "rounding_mode": rm, + }, + # no overflow 01 + { + "sign": 0, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 0, + "sticky_bit": 1, + "rounding_mode": rm, + }, + { + "sign": 1, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 0, + "sticky_bit": 1, + "rounding_mode": rm, + }, + # no overflow 11 + { + "sign": 0, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 1, + "sticky_bit": 1, + "rounding_mode": rm, + }, + { + "sign": 1, + "sig": help_values.not_max_norm_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 1, + "sticky_bit": 1, + "rounding_mode": rm, + }, + # Round to nearest tie to even + { + "sign": 1, + "sig": help_values.not_max_norm_even_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 1, + "sticky_bit": 0, + "rounding_mode": rm, + }, + { + "sign": 0, + "sig": help_values.not_max_norm_even_sig, + "exp": help_values.not_max_norm_exp, + "round_bit": 1, + "sticky_bit": 0, + "rounding_mode": rm, + }, + ] + expected_results = [ + # carry after increment + { + "sig": (help_values.max_sig + 1) >> 1 if rm != RoundingModes.ROUND_ZERO else help_values.max_sig, + "exp": ( + help_values.not_max_norm_exp + 1 + if rm != RoundingModes.ROUND_ZERO + else help_values.not_max_norm_exp + ), + "inexact": 1, + }, + # no overflow 00 + { + "sig": help_values.not_max_norm_sig + inc_arr[0], + "exp": help_values.not_max_norm_exp, + "inexact": 0, + }, + { + "sig": help_values.not_max_norm_sig + inc_arr[4], + "exp": help_values.not_max_norm_exp, + "inexact": 0, + }, + # no overflow 01 + { + "sig": help_values.not_max_norm_sig + inc_arr[1], + "exp": help_values.not_max_norm_exp, + "inexact": 1, + }, + { + "sig": help_values.not_max_norm_sig + inc_arr[5], + "exp": help_values.not_max_norm_exp, + "inexact": 1, + }, + # no overflow 10 + { + "sig": help_values.not_max_norm_sig + inc_arr[2], + "exp": help_values.not_max_norm_exp, + "inexact": 1, + }, + { + "sig": help_values.not_max_norm_sig + inc_arr[6], + "exp": help_values.not_max_norm_exp, + "inexact": 1, + }, + # no overflow 11 + { + "sig": help_values.not_max_norm_sig + inc_arr[3], + "exp": help_values.not_max_norm_exp, + "inexact": 1, + }, + { + "sig": help_values.not_max_norm_sig + inc_arr[7], + "exp": help_values.not_max_norm_exp, + "inexact": 1, + }, + # Round to nearest tie to even + { + "sig": help_values.not_max_norm_even_sig, + "exp": 
help_values.not_max_norm_exp, + "inexact": 1, + }, + { + "sig": help_values.not_max_norm_even_sig, + "exp": help_values.not_max_norm_exp, + "inexact": 1, + }, + ] + + num_of_test_cases = len(test_cases) if rm == RoundingModes.ROUND_NEAREST_EVEN else len(test_cases) - 2 + + for i in range(num_of_test_cases): + + resp = await fpurt.rounding_request_adapter.call(sim, test_cases[i]) + assert resp.exp == expected_results[i]["exp"] + assert resp.sig == expected_results[i]["sig"] + assert resp.inexact == expected_results[i]["inexact"] + + async def test_process(sim: TestbenchContext): + await one_rounding_mode_test(sim) + + with self.run_simulation(fpurt) as sim: + sim.add_testbench(test_process) diff --git a/test/func_blocks/fu/functional_common.py b/test/func_blocks/fu/functional_common.py index 088c4337d..85d34ab1d 100644 --- a/test/func_blocks/fu/functional_common.py +++ b/test/func_blocks/fu/functional_common.py @@ -6,11 +6,11 @@ from typing import Generic, TypeVar from amaranth import Elaboratable, Signal -from amaranth.sim import Passive, Tick from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config from coreblocks.priv.csr.csr_instances import GenericCSRRegisters +from transactron.testing.functions import data_const_to_dict from transactron.utils.dependencies import DependencyContext from coreblocks.params.fu_params import FunctionalComponentParams from coreblocks.arch import Funct3, Funct7 @@ -18,7 +18,14 @@ from coreblocks.interface.layouts import ExceptionRegisterLayouts from coreblocks.arch.optypes import OpType from transactron.lib import Adapter -from transactron.testing import RecordIntDict, RecordIntDictRet, TestbenchIO, TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import ( + RecordIntDict, + TestbenchIO, + TestCaseWithSimulator, + SimpleTestCircuit, + ProcessContext, + TestbenchContext, +) from transactron.utils import ModuleConnector @@ -111,8 +118,8 @@ def setup(self, fixture_initialize_testing_env): random.seed(self.seed) self.requests = deque[RecordIntDict]() - self.responses = deque[RecordIntDictRet]() - self.exceptions = deque[RecordIntDictRet]() + self.responses = deque[RecordIntDict]() + self.exceptions = deque[RecordIntDict]() max_int = 2**self.gen_params.isa.xlen - 1 functions = list(self.ops.keys()) @@ -158,37 +165,38 @@ def setup(self, fixture_initialize_testing_env): self.responses.append({"rob_id": rob_id, "rp_dst": rp_dst, "exception": int(cause is not None)} | results) - def consumer(self): + async def consumer(self, sim: TestbenchContext): while self.responses: expected = self.responses.pop() - result = yield from self.m.accept.call() - assert expected == result - yield from self.random_wait(self.max_wait) + result = await self.m.accept.call(sim) + assert expected == data_const_to_dict(result) + await self.random_wait(sim, self.max_wait) - def producer(self): + async def producer(self, sim: TestbenchContext): while self.requests: req = self.requests.pop() - yield from self.m.issue.call(req) - yield from self.random_wait(self.max_wait) - - def exception_consumer(self): - while self.exceptions: - expected = self.exceptions.pop() - result = yield from self.report_mock.call() - assert expected == result - yield from self.random_wait(self.max_wait) + await self.m.issue.call(sim, req) + await self.random_wait(sim, self.max_wait) + + async def exception_consumer(self, sim: TestbenchContext): + # This is a background testbench so that extra calls can be detected reliably + with sim.critical(): + while 
self.exceptions: + expected = self.exceptions.pop() + result = await self.report_mock.call(sim) + assert expected == data_const_to_dict(result) + await self.random_wait(sim, self.max_wait) # keep partially dependent tests from hanging up and detect extra calls - yield Passive() - result = yield from self.report_mock.call() + result = await self.report_mock.call(sim) assert not True, "unexpected report call" - def pipeline_verifier(self): - yield Passive() - while True: - assert (yield self.m.issue.adapter.iface.ready) - assert (yield self.m.issue.adapter.en) == (yield self.m.issue.adapter.done) - yield Tick() + async def pipeline_verifier(self, sim: ProcessContext): + async for *_, ready, en, done in sim.tick().sample( + self.m.issue.adapter.iface.ready, self.m.issue.adapter.en, self.m.issue.adapter.done + ): + assert ready + assert en == done def run_standard_fu_test(self, pipeline_test=False): if pipeline_test: @@ -197,8 +205,8 @@ def run_standard_fu_test(self, pipeline_test=False): self.max_wait = 10 with self.run_simulation(self.circ) as sim: - sim.add_process(self.producer) - sim.add_process(self.consumer) - sim.add_process(self.exception_consumer) + sim.add_testbench(self.producer) + sim.add_testbench(self.consumer) + sim.add_testbench(self.exception_consumer, background=True) if pipeline_test: sim.add_process(self.pipeline_verifier) diff --git a/test/func_blocks/fu/test_fu_decoder.py b/test/func_blocks/fu/test_fu_decoder.py index 9c6601f3c..cedaf93b1 100644 --- a/test/func_blocks/fu/test_fu_decoder.py +++ b/test/func_blocks/fu/test_fu_decoder.py @@ -1,9 +1,7 @@ import random -from typing import Sequence, Generator -from amaranth import * -from amaranth.sim import * +from collections.abc import Sequence -from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator +from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchContext from coreblocks.func_blocks.fu.common.fu_decoder import DecoderManager, Decoder from coreblocks.arch import OpType, Funct3, Funct7 @@ -31,21 +29,19 @@ def expected_results(self, instructions: Sequence[tuple], op_type_dependent: boo return acc - def handle_signals(self, decoder: Decoder, exec_fn: dict[str, int]) -> Generator: - yield decoder.exec_fn.op_type.eq(exec_fn["op_type"]) - yield decoder.exec_fn.funct3.eq(exec_fn["funct3"]) - yield decoder.exec_fn.funct7.eq(exec_fn["funct7"]) + async def handle_signals(self, sim: TestbenchContext, decoder: Decoder, exec_fn: dict[str, int]): + sim.set(decoder.exec_fn.op_type, exec_fn["op_type"]) + sim.set(decoder.exec_fn.funct3, exec_fn["funct3"]) + sim.set(decoder.exec_fn.funct7, exec_fn["funct7"]) - yield Settle() - - return (yield decoder.decode_fn) + return sim.get(decoder.decode_fn) def run_test_case(self, decoder_manager: DecoderManager, test_inputs: Sequence[tuple]) -> None: instructions = decoder_manager.get_instructions() decoder = decoder_manager.get_decoder(self.gen_params) op_type_dependent = len(decoder_manager.get_op_types()) != 1 - def process(): + async def process(sim: TestbenchContext): for test_input in test_inputs: exec_fn = { "op_type": test_input[1], @@ -53,7 +49,7 @@ def process(): "funct7": test_input[3] if len(test_input) >= 4 else 0, } - returned = yield from self.handle_signals(decoder, exec_fn) + returned = await self.handle_signals(sim, decoder, exec_fn) expected = self.expected_results(instructions, op_type_dependent, exec_fn) assert returned == expected @@ -61,7 +57,7 @@ def process(): test_circuit = SimpleTestCircuit(decoder) with
self.run_simulation(test_circuit) as sim: - sim.add_process(process) + sim.add_testbench(process) def generate_random_instructions(self) -> Sequence[tuple]: random.seed(42) diff --git a/test/func_blocks/fu/test_pipelined_mul_unit.py b/test/func_blocks/fu/test_pipelined_mul_unit.py index 1c955c6b4..20b46ff14 100644 --- a/test/func_blocks/fu/test_pipelined_mul_unit.py +++ b/test/func_blocks/fu/test_pipelined_mul_unit.py @@ -2,15 +2,15 @@ import math from collections import deque -from amaranth.sim import Settle from parameterized import parameterized_class from coreblocks.func_blocks.fu.unsigned_multiplication.pipelined import PipelinedUnsignedMul -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config +from transactron.testing.functions import data_const_to_dict @parameterized_class( @@ -57,14 +57,14 @@ def setup_method(self): ) def test_pipeline(self): - def consumer(): + async def consumer(sim: TestbenchContext): time = 0 while self.responses: - res = yield from self.m.accept.call_try() + res = await self.m.accept.call_try(sim) time += 1 if res is not None: expected = self.responses.pop() - assert expected == res + assert expected == data_const_to_dict(res) assert ( time @@ -73,12 +73,11 @@ def consumer(): + 2 ) - def producer(): + async def producer(sim: TestbenchContext): while self.requests: req = self.requests.pop() - yield Settle() - yield from self.m.issue.call(req) + await self.m.issue.call(sim, req) with self.run_simulation(self.m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) diff --git a/test/func_blocks/fu/test_unsigned_mul_unit.py b/test/func_blocks/fu/test_unsigned_mul_unit.py index 06321672c..bb522c73c 100644 --- a/test/func_blocks/fu/test_unsigned_mul_unit.py +++ b/test/func_blocks/fu/test_unsigned_mul_unit.py @@ -1,8 +1,6 @@ import random from collections import deque -from typing import Type -from amaranth.sim import Settle from parameterized import parameterized_class from coreblocks.func_blocks.fu.unsigned_multiplication.common import MulBaseUnsigned @@ -11,10 +9,11 @@ from coreblocks.func_blocks.fu.unsigned_multiplication.shift import ShiftUnsignedMul from coreblocks.func_blocks.fu.unsigned_multiplication.pipelined import PipelinedUnsignedMul -from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit +from transactron.testing import TestCaseWithSimulator, SimpleTestCircuit, TestbenchContext from coreblocks.params import GenParams from coreblocks.params.configurations import test_core_config +from transactron.testing.functions import data_const_to_dict @parameterized_class( @@ -39,7 +38,7 @@ ], ) class TestUnsignedMultiplicationUnit(TestCaseWithSimulator): - mul_unit: Type[MulBaseUnsigned] + mul_unit: type[MulBaseUnsigned] def setup_method(self): self.gen_params = GenParams(test_core_config) @@ -68,20 +67,19 @@ def setup_method(self): ) def test_pipeline(self): - def consumer(): + async def consumer(sim: TestbenchContext): while self.responses: expected = self.responses.pop() - result = yield from self.m.accept.call() - assert expected == result - yield from self.random_wait(self.waiting_time) + result = await self.m.accept.call(sim) + assert expected == data_const_to_dict(result) + await self.random_wait(sim, self.waiting_time) - def producer(): + async def 
producer(sim: TestbenchContext): while self.requests: req = self.requests.pop() - yield Settle() - yield from self.m.issue.call(req) - yield from self.random_wait(self.waiting_time) + await self.m.issue.call(sim, req) + await self.random_wait(sim, self.waiting_time) with self.run_simulation(self.m) as sim: - sim.add_process(producer) - sim.add_process(consumer) + sim.add_testbench(producer) + sim.add_testbench(consumer) diff --git a/test/func_blocks/lsu/test_dummylsu.py b/test/func_blocks/lsu/test_dummylsu.py index 976550a69..3a13149dc 100644 --- a/test/func_blocks/lsu/test_dummylsu.py +++ b/test/func_blocks/lsu/test_dummylsu.py @@ -1,19 +1,18 @@ import random from collections import deque -from amaranth.sim import Settle, Passive, Tick - from transactron.lib import Adapter +from transactron.testing.method_mock import MethodMock from transactron.utils import int_to_signed, signed_to_int from coreblocks.params import GenParams from coreblocks.func_blocks.fu.lsu.dummyLsu import LSUDummy from coreblocks.params.configurations import test_core_config from coreblocks.arch import * -from coreblocks.interface.keys import ExceptionReportKey, InstructionPrecommitKey +from coreblocks.interface.keys import CoreStateKey, ExceptionReportKey, InstructionPrecommitKey from transactron.utils.dependencies import DependencyContext from coreblocks.interface.layouts import ExceptionRegisterLayouts, RetirementLayouts from coreblocks.peripherals.wishbone import * -from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock +from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock, TestbenchContext from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter from test.peripherals.test_wishbone import WishboneInterfaceWrapper @@ -84,11 +83,20 @@ def elaborate(self, platform): DependencyContext.get().add_dependency(ExceptionReportKey(), self.exception_report.adapter.iface) + layouts = self.gen.get(RetirementLayouts) m.submodules.precommit = self.precommit = TestbenchIO( - Adapter(o=self.gen.get(RetirementLayouts).precommit, nonexclusive=True) + Adapter( + i=layouts.precommit_in, + o=layouts.precommit_out, + nonexclusive=True, + combiner=lambda m, args, runs: args[0], + ).set(with_validate_arguments=True) ) DependencyContext.get().add_dependency(InstructionPrecommitKey(), self.precommit.adapter.iface) + m.submodules.core_state = self.core_state = TestbenchIO(Adapter(o=layouts.core_state, nonexclusive=True)) + DependencyContext.get().add_dependency(CoreStateKey(), self.core_state.adapter.iface) + m.submodules.func_unit = func_unit = LSUDummy(self.gen, self.bus_master_adapter) m.submodules.issue_mock = self.issue = TestbenchIO(AdapterTrans(func_unit.issue)) @@ -190,11 +198,8 @@ def setup_method(self) -> None: self.generate_instr(2**7, 2**7) self.max_wait = 10 - def wishbone_slave(self): - yield Passive() - - while True: - yield from self.test_module.io_in.slave_wait() + async def wishbone_slave(self, sim: TestbenchContext): + while self.mem_data_queue: generated_data = self.mem_data_queue.pop() if generated_data["misaligned"]: @@ -202,8 +207,8 @@ def wishbone_slave(self): mask = generated_data["mask"] sign = generated_data["sign"] - yield from self.test_module.io_in.slave_verify(generated_data["addr"], 0, 0, mask) - yield from self.random_wait(self.max_wait) + await self.test_module.io_in.slave_wait_and_verify(sim, generated_data["addr"], 0, 0, mask) + await self.random_wait(sim, self.max_wait) resp_data = int((generated_data["rnd_bytes"][:4]).hex(), 16) 
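             # (mask & -mask) keeps only the lowest set bit of the byte-select mask; bit_length() - 1 is its index, i.e. the first selected byte lane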
data_shift = (mask & -mask).bit_length() - 1 @@ -216,21 +221,20 @@ def wishbone_slave(self): data = int_to_signed(signed_to_int(data, size), 32) if not generated_data["err"]: self.returned_data.appendleft(data) - yield from self.test_module.io_in.slave_respond(resp_data, err=generated_data["err"]) - yield Settle() + await self.test_module.io_in.slave_respond(sim, resp_data, err=generated_data["err"]) - def inserter(self): + async def inserter(self, sim: TestbenchContext): for i in range(self.tests_number): req = self.instr_queue.pop() while req["rob_id"] not in self.free_rob_id: - yield Tick() + await sim.tick() self.free_rob_id.remove(req["rob_id"]) - yield from self.test_module.issue.call(req) - yield from self.random_wait(self.max_wait) + await self.test_module.issue.call(sim, req) + await self.random_wait(sim, self.max_wait) - def consumer(self): + async def consumer(self, sim: TestbenchContext): for i in range(self.tests_number): - v = yield from self.test_module.accept.call() + v = await self.test_module.accept.call(sim) rob_id = v["rob_id"] assert rob_id not in self.free_rob_id self.free_rob_id.add(rob_id) @@ -241,17 +245,27 @@ def consumer(self): assert v["result"] == self.returned_data.pop() assert v["exception"] == exc["err"] - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) def test(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert arg == self.exception_queue.pop() + @MethodMock.effect + def eff(): + assert arg == self.exception_queue.pop() + + @def_method_mock(lambda: self.test_module.precommit, validate_arguments=lambda rob_id: True) + def precommiter(rob_id): + return {"side_fx": 1} + + @def_method_mock(lambda: self.test_module.core_state) + def core_state_process(): + return {"flushing": 0} with self.run_simulation(self.test_module) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.inserter) - sim.add_process(self.consumer) + sim.add_testbench(self.wishbone_slave, background=True) + sim.add_testbench(self.inserter) + sim.add_testbench(self.consumer) class TestDummyLSULoadsCycles(TestCaseWithSimulator): @@ -284,29 +298,33 @@ def setup_method(self) -> None: self.gen_params = GenParams(test_core_config.replace(phys_regs_bits=3, rob_entries_bits=3)) self.test_module = DummyLSUTestCircuit(self.gen_params) - def one_instr_test(self): + async def one_instr_test(self, sim: TestbenchContext): instr, wish_data = self.generate_instr(2**7, 2**7) - yield from self.test_module.issue.call(instr) - yield from self.test_module.io_in.slave_wait() + await self.test_module.issue.call(sim, instr) mask = wish_data["mask"] - yield from self.test_module.io_in.slave_verify(wish_data["addr"], 0, 0, mask) + await self.test_module.io_in.slave_wait_and_verify(sim, wish_data["addr"], 0, 0, mask) data = wish_data["rnd_bytes"][:4] data = int(data.hex(), 16) - yield from self.test_module.io_in.slave_respond(data) - yield Settle() + await self.test_module.io_in.slave_respond(sim, data) - v = yield from self.test_module.accept.call() + v = await self.test_module.accept.call(sim) assert v["result"] == data def test(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert False + @MethodMock.effect + def eff(): + assert False + + @def_method_mock(lambda: self.test_module.precommit, validate_arguments=lambda rob_id: True) + def precommiter(rob_id): + return {"side_fx": 1} with self.run_simulation(self.test_module) as sim: - 
sim.add_process(self.one_instr_test) + sim.add_testbench(self.one_instr_test) class TestDummyLSUStores(TestCaseWithSimulator): @@ -360,9 +378,8 @@ def setup_method(self) -> None: self.generate_instr(2**7, 2**7) self.max_wait = 8 - def wishbone_slave(self): + async def wishbone_slave(self, sim: TestbenchContext): for i in range(self.tests_number): - yield from self.test_module.io_in.slave_wait() generated_data = self.mem_data_queue.pop() mask = generated_data["mask"] @@ -374,71 +391,69 @@ def wishbone_slave(self): data = (int(generated_data["data"][-2:].hex(), 16) & 0xFFFF) << h_dict[mask] else: data = int(generated_data["data"][-4:].hex(), 16) - yield from self.test_module.io_in.slave_verify(generated_data["addr"], data, 1, mask) - yield from self.random_wait(self.max_wait) + await self.test_module.io_in.slave_wait_and_verify(sim, generated_data["addr"], data, 1, mask) + await self.random_wait(sim, self.max_wait) - yield from self.test_module.io_in.slave_respond(0) - yield Settle() + await self.test_module.io_in.slave_respond(sim, 0) - def inserter(self): + async def inserter(self, sim: TestbenchContext): for i in range(self.tests_number): req = self.instr_queue.pop() self.get_result_data.appendleft(req["rob_id"]) - yield from self.test_module.issue.call(req) + await self.test_module.issue.call(sim, req) self.precommit_data.appendleft(req["rob_id"]) - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) - def get_resulter(self): + async def get_resulter(self, sim: TestbenchContext): for i in range(self.tests_number): - v = yield from self.test_module.accept.call() + v = await self.test_module.accept.call(sim) rob_id = self.get_result_data.pop() assert v["rob_id"] == rob_id assert v["rp_dst"] == 0 - yield from self.random_wait(self.max_wait) + await self.random_wait(sim, self.max_wait) self.precommit_data.pop() # retire - def precommiter(self): - yield Passive() - while True: - while len(self.precommit_data) == 0: - yield Tick() - rob_id = self.precommit_data[-1] # precommit is called continously until instruction is retired - yield from self.test_module.precommit.call(rob_id=rob_id, side_fx=1) + def precommit_validate(self, rob_id): + return len(self.precommit_data) > 0 and rob_id == self.precommit_data[-1] + + @def_method_mock(lambda self: self.test_module.precommit, validate_arguments=precommit_validate) + def precommiter(self, rob_id): + return {"side_fx": 1} def test(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert False + @MethodMock.effect + def eff(): + assert False with self.run_simulation(self.test_module) as sim: - sim.add_process(self.wishbone_slave) - sim.add_process(self.inserter) - sim.add_process(self.get_resulter) - sim.add_process(self.precommiter) + sim.add_testbench(self.wishbone_slave) + sim.add_testbench(self.inserter) + sim.add_testbench(self.get_resulter) class TestDummyLSUFence(TestCaseWithSimulator): def get_instr(self, exec_fn): return {"rp_dst": 1, "rob_id": 1, "exec_fn": exec_fn, "s1_val": 4, "s2_val": 1, "imm": 8, "pc": 0} - def push_one_instr(self, instr): - yield from self.test_module.issue.call(instr) + async def push_one_instr(self, sim: TestbenchContext, instr): + await self.test_module.issue.call(sim, instr) if instr["exec_fn"]["op_type"] == OpType.LOAD: - yield from self.test_module.io_in.slave_wait() - yield from self.test_module.io_in.slave_respond(1) - yield Settle() - v = yield from self.test_module.accept.call() + await self.test_module.io_in.slave_wait(sim) + 
await self.test_module.io_in.slave_respond(sim, 1)
+        v = await self.test_module.accept.call(sim)
         if instr["exec_fn"]["op_type"] == OpType.LOAD:
-            assert v["result"] == 1
+            assert v.result == 1
 
-    def process(self):
+    async def process(self, sim: TestbenchContext):
         # just tests that FENCE doesn't hang the LSU
         load_fn = {"op_type": OpType.LOAD, "funct3": Funct3.W, "funct7": 0}
         fence_fn = {"op_type": OpType.FENCE, "funct3": 0, "funct7": 0}
-        yield from self.push_one_instr(self.get_instr(load_fn))
-        yield from self.push_one_instr(self.get_instr(fence_fn))
-        yield from self.push_one_instr(self.get_instr(load_fn))
+        await self.push_one_instr(sim, self.get_instr(load_fn))
+        await self.push_one_instr(sim, self.get_instr(fence_fn))
+        await self.push_one_instr(sim, self.get_instr(load_fn))
 
     def test_fence(self):
         self.gen_params = GenParams(test_core_config.replace(phys_regs_bits=3, rob_entries_bits=3))
@@ -446,7 +461,13 @@ def test_fence(self):
 
         @def_method_mock(lambda: self.test_module.exception_report)
         def exception_consumer(arg):
-            assert False
+            @MethodMock.effect
+            def eff():
+                assert False
+
+        @def_method_mock(lambda: self.test_module.precommit, validate_arguments=lambda rob_id: True)
+        def precommiter(rob_id):
+            return {"side_fx": 1}
 
         with self.run_simulation(self.test_module) as sim:
-            sim.add_process(self.process)
+            sim.add_testbench(self.process)
diff --git a/test/func_blocks/lsu/test_pma.py b/test/func_blocks/lsu/test_pma.py
index 81cacde33..16e8aec4b 100644
--- a/test/func_blocks/lsu/test_pma.py
+++ b/test/func_blocks/lsu/test_pma.py
@@ -1,4 +1,4 @@
-from amaranth.sim import Settle
+import random
 
 from coreblocks.func_blocks.fu.lsu.pma import PMAChecker, PMARegion
 
 from transactron.lib import Adapter
@@ -6,26 +6,26 @@
 from coreblocks.func_blocks.fu.lsu.dummyLsu import LSUDummy
 from coreblocks.params.configurations import test_core_config
 from coreblocks.arch import *
-from coreblocks.interface.keys import ExceptionReportKey, InstructionPrecommitKey
+from coreblocks.interface.keys import CoreStateKey, ExceptionReportKey, InstructionPrecommitKey
+from transactron.testing.method_mock import MethodMock
 from transactron.utils.dependencies import DependencyContext
 from coreblocks.interface.layouts import ExceptionRegisterLayouts, RetirementLayouts
 from coreblocks.peripherals.wishbone import *
-from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock
+from transactron.testing import TestbenchIO, TestCaseWithSimulator, def_method_mock, TestbenchContext
 from coreblocks.peripherals.bus_adapter import WishboneMasterAdapter
 from test.peripherals.test_wishbone import WishboneInterfaceWrapper
 
 
 class TestPMADirect(TestCaseWithSimulator):
-    def verify_region(self, region: PMARegion):
+    async def verify_region(self, sim: TestbenchContext, region: PMARegion):
         for i in range(region.start, region.end + 1):
-            yield self.test_module.addr.eq(i)
-            yield Settle()
-            mmio = yield self.test_module.result["mmio"]
+            sim.set(self.test_module.addr, i)
+            mmio = sim.get(self.test_module.result.mmio)
             assert mmio == region.mmio
 
-    def process(self):
+    async def process(self, sim: TestbenchContext):
         for r in self.pma_regions:
-            yield from self.verify_region(r)
+            await self.verify_region(sim, r)
 
     def test_pma_direct(self):
         self.pma_regions = [
@@ -40,7 +40,7 @@ def test_pma_direct(self):
         self.test_module = PMAChecker(self.gen_params)
 
         with self.run_simulation(self.test_module) as sim:
-            sim.add_process(self.process)
+            sim.add_testbench(self.process)
 
 
 class PMAIndirectTestCircuit(Elaboratable):
@@ -64,11 
+64,20 @@ def elaborate(self, platform): DependencyContext.get().add_dependency(ExceptionReportKey(), self.exception_report.adapter.iface) + layouts = self.gen.get(RetirementLayouts) m.submodules.precommit = self.precommit = TestbenchIO( - Adapter(o=self.gen.get(RetirementLayouts).precommit, nonexclusive=True) + Adapter( + i=layouts.precommit_in, + o=layouts.precommit_out, + nonexclusive=True, + combiner=lambda m, args, runs: args[0], + ).set(with_validate_arguments=True) ) DependencyContext.get().add_dependency(InstructionPrecommitKey(), self.precommit.adapter.iface) + m.submodules.core_state = self.core_state = TestbenchIO(Adapter(o=layouts.core_state, nonexclusive=True)) + DependencyContext.get().add_dependency(CoreStateKey(), self.core_state.adapter.iface) + m.submodules.func_unit = func_unit = LSUDummy(self.gen, self.bus_master_adapter) m.submodules.issue_mock = self.issue = TestbenchIO(AdapterTrans(func_unit.issue)) @@ -91,27 +100,24 @@ def get_instr(self, addr): "pc": 0, } - def verify_region(self, region: PMARegion): + async def verify_region(self, sim: TestbenchContext, region: PMARegion): for addr in range(region.start, region.end + 1): instr = self.get_instr(addr) - yield from self.test_module.issue.call(instr) + await self.test_module.issue.call(sim, instr) if region.mmio is True: wb = self.test_module.io_in.wb for i in range(100): # 100 cycles is more than enough - wb_requested = (yield wb.stb) and (yield wb.cyc) + wb_requested = sim.get(wb.stb) and sim.get(wb.cyc) assert not wb_requested - yield from self.test_module.precommit.call(rob_id=1, side_fx=1) - - yield from self.test_module.io_in.slave_wait() - yield from self.test_module.io_in.slave_respond((addr << (addr % 4) * 8)) - yield Settle() - v = yield from self.test_module.accept.call() - assert v["result"] == addr + await self.test_module.io_in.slave_wait(sim) + await self.test_module.io_in.slave_respond(sim, (addr << (addr % 4) * 8)) + v = await self.test_module.accept.call(sim) + assert v.result == addr - def process(self): + async def process(self, sim: TestbenchContext): for region in self.pma_regions: - yield from self.verify_region(region) + await self.verify_region(sim, region) def test_pma_indirect(self): self.pma_regions = [ @@ -124,7 +130,21 @@ def test_pma_indirect(self): @def_method_mock(lambda: self.test_module.exception_report) def exception_consumer(arg): - assert False + @MethodMock.effect + def eff(): + assert False + + @def_method_mock( + lambda: self.test_module.precommit, + validate_arguments=lambda rob_id: rob_id == 1, + enable=lambda: random.random() < 0.5, + ) + def precommiter(rob_id): + return {"side_fx": 1} + + @def_method_mock(lambda: self.test_module.core_state) + def core_state_process(): + return {"flushing": 0} with self.run_simulation(self.test_module) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) diff --git a/test/params/test_configurations.py b/test/params/test_configurations.py index ec78cf9e8..a82cba1f8 100644 --- a/test/params/test_configurations.py +++ b/test/params/test_configurations.py @@ -18,9 +18,9 @@ class ISAStrTest: TEST_CASES = [ ISAStrTest( basic_core_config, - "rv32izicsr_zifencei_xintmachinemode", - "rv32izicsr_zifencei_xintmachinemode", - "rv32izicsr_zifencei_xintmachinemode", + "rv32imzicsr_zifencei_xintmachinemode", + "rv32imzicsr_zifencei_xintmachinemode", + "rv32imzicsr_zifencei_xintmachinemode", ), ISAStrTest( full_core_config, @@ -28,7 +28,7 @@ class ISAStrTest: "rv32imcbzicsr_zifencei_zicond_xintmachinemode", 
"rv32imcbzicsr_zifencei_zicond_xintmachinemode", ), - ISAStrTest(tiny_core_config, "rv32e", "rv32", "rv32e"), + ISAStrTest(tiny_core_config, "rv32e", "rv32e", "rv32e"), ISAStrTest(test_core_config, "rv32", "rv32", "rv32i"), ] diff --git a/test/peripherals/test_axi_lite.py b/test/peripherals/test_axi_lite.py index 514e85ea9..d1156271b 100644 --- a/test/peripherals/test_axi_lite.py +++ b/test/peripherals/test_axi_lite.py @@ -9,66 +9,77 @@ class AXILiteInterfaceWrapper: def __init__(self, axi_lite_master: AXILiteInterface): self.axi_lite = axi_lite_master - def slave_ra_ready(self, rdy=1): - yield self.axi_lite.read_address.rdy.eq(rdy) - - def slave_ra_wait(self): - while not (yield self.axi_lite.read_address.valid): - yield Tick() - - def slave_ra_verify(self, exp_addr, prot): - assert (yield self.axi_lite.read_address.valid) - assert (yield self.axi_lite.read_address.addr) == exp_addr - assert (yield self.axi_lite.read_address.prot) == prot - - def slave_rd_wait(self): - while not (yield self.axi_lite.read_data.rdy): - yield Tick() - - def slave_rd_respond(self, data, resp=0): - assert (yield self.axi_lite.read_data.rdy) - yield self.axi_lite.read_data.data.eq(data) - yield self.axi_lite.read_data.resp.eq(resp) - yield self.axi_lite.read_data.valid.eq(1) - yield Tick() - yield self.axi_lite.read_data.valid.eq(0) - - def slave_wa_ready(self, rdy=1): - yield self.axi_lite.write_address.rdy.eq(rdy) - - def slave_wa_wait(self): - while not (yield self.axi_lite.write_address.valid): - yield Tick() - - def slave_wa_verify(self, exp_addr, prot): - assert (yield self.axi_lite.write_address.valid) - assert (yield self.axi_lite.write_address.addr) == exp_addr - assert (yield self.axi_lite.write_address.prot) == prot - - def slave_wd_ready(self, rdy=1): - yield self.axi_lite.write_data.rdy.eq(rdy) - - def slave_wd_wait(self): - while not (yield self.axi_lite.write_data.valid): - yield Tick() - - def slave_wd_verify(self, exp_data, strb): - assert (yield self.axi_lite.write_data.valid) - assert (yield self.axi_lite.write_data.data) == exp_data - assert (yield self.axi_lite.write_data.strb) == strb - - def slave_wr_wait(self): - while not (yield self.axi_lite.write_response.rdy): - yield Tick() - - def slave_wr_respond(self, resp=0): - assert (yield self.axi_lite.write_response.rdy) - yield self.axi_lite.write_response.resp.eq(resp) - yield self.axi_lite.write_response.valid.eq(1) - yield Tick() - yield self.axi_lite.write_response.valid.eq(0) - - + def slave_ra_ready(self, sim: TestbenchContext, rdy=1): + sim.set(self.axi_lite.read_address.rdy, rdy) + + def slave_ra_get(self, sim: TestbenchContext): + ra = self.axi_lite.read_address + assert sim.get(ra.valid) + return sim.get(ra.addr), sim.get(ra.prot) + + def slave_ra_get_and_verify(self, sim: TestbenchContext, exp_addr: int, exp_prot: int): + addr, prot = self.slave_ra_get(sim) + assert addr == exp_addr + assert prot == exp_prot + + async def slave_rd_wait(self, sim: TestbenchContext): + rd = self.axi_lite.read_data + while not sim.get(rd.rdy): + await sim.tick() + + def slave_rd_get(self, sim: TestbenchContext): + rd = self.axi_lite.read_data + assert sim.get(rd.rdy) + + async def slave_rd_respond(self, sim: TestbenchContext, data, resp=0): + assert sim.get(self.axi_lite.read_data.rdy) + sim.set(self.axi_lite.read_data.data, data) + sim.set(self.axi_lite.read_data.resp, resp) + sim.set(self.axi_lite.read_data.valid, 1) + await sim.tick() + sim.set(self.axi_lite.read_data.valid, 0) + + def slave_wa_ready(self, sim: TestbenchContext, rdy=1): + 
sim.set(self.axi_lite.write_address.rdy, rdy) + + def slave_wa_get(self, sim: TestbenchContext): + wa = self.axi_lite.write_address + assert sim.get(wa.valid) + return sim.get(wa.addr), sim.get(wa.prot) + + def slave_wa_get_and_verify(self, sim: TestbenchContext, exp_addr, exp_prot): + addr, prot = self.slave_wa_get(sim) + assert addr == exp_addr + assert prot == exp_prot + + def slave_wd_ready(self, sim: TestbenchContext, rdy=1): + sim.set(self.axi_lite.write_data.rdy, rdy) + + def slave_wd_get(self, sim: TestbenchContext): + wd = self.axi_lite.write_data + assert sim.get(wd.valid) + return sim.get(wd.data), sim.get(wd.strb) + + def slave_wd_get_and_verify(self, sim: TestbenchContext, exp_data, exp_strb): + data, strb = self.slave_wd_get(sim) + assert data == exp_data + assert strb == exp_strb + + def slave_wr_get(self, sim: TestbenchContext): + wr = self.axi_lite.write_response + assert sim.get(wr.rdy) + + async def slave_wr_respond(self, sim: TestbenchContext, resp=0): + assert sim.get(self.axi_lite.write_response.rdy) + sim.set(self.axi_lite.write_response.resp, resp) + sim.set(self.axi_lite.write_response.valid, 1) + await sim.tick() + sim.set(self.axi_lite.write_response.valid, 0) + + +# TODO: this test needs a rewrite! +# 1. use queues instead of copy-pasting +# 2. handle each AXI pipe independently class TestAXILiteMaster(TestCaseWithSimulator): class AXILiteMasterTestModule(Elaboratable): def __init__(self, params: AXILiteParameters): @@ -103,161 +114,141 @@ def _(arg): def test_manual(self): almt = TestAXILiteMaster.AXILiteMasterTestModule(AXILiteParameters()) - def master_process(): + async def master_process(sim: TestbenchContext): # read request - yield from almt.read_address_request_adapter.call(addr=5, prot=0) + await almt.read_address_request_adapter.call(sim, addr=5, prot=0) - yield from almt.read_address_request_adapter.call(addr=10, prot=1) + await almt.read_address_request_adapter.call(sim, addr=10, prot=1) - yield from almt.read_address_request_adapter.call(addr=15, prot=1) + await almt.read_address_request_adapter.call(sim, addr=15, prot=1) - yield from almt.read_address_request_adapter.call(addr=20, prot=0) + await almt.read_address_request_adapter.call(sim, addr=20, prot=0) - yield from almt.write_request_adapter.call(addr=6, prot=0, data=10, strb=3) + await almt.write_request_adapter.call(sim, addr=6, prot=0, data=10, strb=3) - yield from almt.write_request_adapter.call(addr=7, prot=0, data=11, strb=3) + await almt.write_request_adapter.call(sim, addr=7, prot=0, data=11, strb=3) - yield from almt.write_request_adapter.call(addr=8, prot=0, data=12, strb=3) + await almt.write_request_adapter.call(sim, addr=8, prot=0, data=12, strb=3) - yield from almt.write_request_adapter.call(addr=9, prot=1, data=13, strb=4) + await almt.write_request_adapter.call(sim, addr=9, prot=1, data=13, strb=4) - yield from almt.read_address_request_adapter.call(addr=1, prot=1) + await almt.read_address_request_adapter.call(sim, addr=1, prot=1) - yield from almt.read_address_request_adapter.call(addr=2, prot=1) + await almt.read_address_request_adapter.call(sim, addr=2, prot=1) - def slave_process(): + async def slave_process(sim: TestbenchContext): slave = AXILiteInterfaceWrapper(almt.axi_lite_master.axil_master) # 1st request - yield from slave.slave_ra_ready(1) - yield from slave.slave_ra_wait() - yield from slave.slave_ra_verify(5, 0) - yield Settle() + slave.slave_ra_ready(sim, 1) + await sim.tick() + slave.slave_ra_get_and_verify(sim, 5, 0) # 2nd request and 1st respond - yield from 
slave.slave_ra_wait()
-            yield from slave.slave_rd_wait()
-            yield from slave.slave_ra_verify(10, 1)
-            yield from slave.slave_rd_respond(10, 0)
-            yield Settle()
+            await sim.tick()
+            slave.slave_ra_get_and_verify(sim, 10, 1)
+            slave.slave_rd_get(sim)
+            await slave.slave_rd_respond(sim, 10, 0)
 
             # 3rd request and 2nd respond
-            yield from slave.slave_ra_wait()
-            yield from slave.slave_rd_wait()
-            yield from slave.slave_ra_verify(15, 1)
-            yield from slave.slave_rd_respond(15, 0)
-            yield Settle()
+            slave.slave_ra_get_and_verify(sim, 15, 1)
+            slave.slave_rd_get(sim)
+            await slave.slave_rd_respond(sim, 15, 0)
 
             # 4th request and 3rd respond
-            yield from slave.slave_ra_wait()
-            yield from slave.slave_rd_wait()
-            yield from slave.slave_ra_verify(20, 0)
-            yield from slave.slave_rd_respond(20, 0)
-            yield Settle()
+            slave.slave_ra_get_and_verify(sim, 20, 0)
+            slave.slave_rd_get(sim)
+            await slave.slave_rd_respond(sim, 20, 0)
 
             # 4th respond and 1st write request
-            yield from slave.slave_ra_ready(0)
-            yield from slave.slave_wa_ready(1)
-            yield from slave.slave_wd_ready(1)
-            yield from slave.slave_rd_wait()
-            yield from slave.slave_wa_wait()
-            yield from slave.slave_wd_wait()
-            yield from slave.slave_wa_verify(6, 0)
-            yield from slave.slave_wd_verify(10, 3)
-            yield from slave.slave_rd_respond(25, 0)
-            yield Settle()
+            slave.slave_ra_ready(sim, 0)
+            slave.slave_wa_ready(sim, 1)
+            slave.slave_wd_ready(sim, 1)
+            slave.slave_rd_get(sim)
+            slave.slave_wa_get_and_verify(sim, 6, 0)
+            slave.slave_wd_get_and_verify(sim, 10, 3)
+            await slave.slave_rd_respond(sim, 25, 0)
 
             # 2nd write request and 1st respond
-            yield from slave.slave_wa_wait()
-            yield from slave.slave_wd_wait()
-            yield from slave.slave_wr_wait()
-            yield from slave.slave_wa_verify(7, 0)
-            yield from slave.slave_wd_verify(11, 3)
-            yield from slave.slave_wr_respond(1)
-            yield Settle()
+            slave.slave_wa_get_and_verify(sim, 7, 0)
+            slave.slave_wd_get_and_verify(sim, 11, 3)
+            slave.slave_wr_get(sim)
+            await slave.slave_wr_respond(sim, 1)
 
             # 3rd write request and 2nd respond
-            yield from slave.slave_wa_wait()
-            yield from slave.slave_wd_wait()
-            yield from slave.slave_wr_wait()
-            yield from slave.slave_wa_verify(8, 0)
-            yield from slave.slave_wd_verify(12, 3)
-            yield from slave.slave_wr_respond(1)
-            yield Settle()
+            slave.slave_wa_get_and_verify(sim, 8, 0)
+            slave.slave_wd_get_and_verify(sim, 12, 3)
+            slave.slave_wr_get(sim)
+            await slave.slave_wr_respond(sim, 1)
 
             # 4th write request and 3rd respond
-            yield from slave.slave_wr_wait()
-            yield from slave.slave_wa_verify(9, 1)
-            yield from slave.slave_wd_verify(13, 4)
-            yield from slave.slave_wr_respond(1)
-            yield Settle()
+            slave.slave_wa_get_and_verify(sim, 9, 1)
+            slave.slave_wd_get_and_verify(sim, 13, 4)
+            slave.slave_wr_get(sim)
+            await slave.slave_wr_respond(sim, 1)
 
             # 4th respond
-            yield from slave.slave_wa_ready(0)
-            yield from slave.slave_wd_ready(0)
-            yield from slave.slave_wr_wait()
-            yield from slave.slave_wr_respond(0)
-            yield Settle()
-
-            yield from slave.slave_ra_wait()
-            for _ in range(2):
-                yield Tick()
-            yield from slave.slave_ra_ready(1)
-            yield from slave.slave_ra_verify(1, 1)
+            slave.slave_wa_ready(sim, 0)
+            slave.slave_wd_ready(sim, 0)
+            slave.slave_wr_get(sim)
+            await slave.slave_wr_respond(sim, 0)
+
+            slave.slave_ra_get(sim)
+            await self.tick(sim, 2)
+            slave.slave_ra_ready(sim, 1)
+            slave.slave_ra_get_and_verify(sim, 1, 1)
 
             # wait for next rising edge
-            yield Tick()
-            yield Tick()
+            await sim.tick()
 
-            yield from slave.slave_ra_wait()
-            yield from slave.slave_ra_verify(2, 1)
-            yield from 
slave.slave_rd_wait() - yield from slave.slave_rd_respond(3, 1) - yield Settle() + slave.slave_ra_get(sim) + slave.slave_ra_get_and_verify(sim, 2, 1) + slave.slave_rd_get(sim) + await slave.slave_rd_respond(sim, 3, 1) - yield from slave.slave_rd_wait() - yield from slave.slave_rd_respond(4, 1) + await slave.slave_rd_wait(sim) + await slave.slave_rd_respond(sim, 4, 1) - def result_process(): - resp = yield from almt.read_data_response_adapter.call() + async def result_process(sim: TestbenchContext): + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 10 assert resp["resp"] == 0 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 15 assert resp["resp"] == 0 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 20 assert resp["resp"] == 0 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 25 assert resp["resp"] == 0 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 1 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 1 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 1 - resp = yield from almt.write_response_response_adapter.call() + resp = await almt.write_response_response_adapter.call(sim) assert resp["resp"] == 0 for _ in range(5): - yield Tick() + await sim.tick() - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 3 assert resp["resp"] == 1 - resp = yield from almt.read_data_response_adapter.call() + resp = await almt.read_data_response_adapter.call(sim) assert resp["data"] == 4 assert resp["resp"] == 1 with self.run_simulation(almt) as sim: - sim.add_process(master_process) - sim.add_process(slave_process) - sim.add_process(result_process) + sim.add_testbench(master_process) + sim.add_testbench(slave_process) + sim.add_testbench(result_process) diff --git a/test/peripherals/test_wishbone.py b/test/peripherals/test_wishbone.py index 68b04c6a0..d5a19f19b 100644 --- a/test/peripherals/test_wishbone.py +++ b/test/peripherals/test_wishbone.py @@ -1,7 +1,9 @@ +from collections.abc import Iterable import random from collections import deque from amaranth.lib.wiring import connect +from amaranth_types import ValueLike from coreblocks.peripherals.wishbone import * @@ -14,50 +16,80 @@ class WishboneInterfaceWrapper: def __init__(self, wishbone_interface: WishboneInterface): self.wb = wishbone_interface - def master_set(self, addr, data, we): - yield self.wb.dat_w.eq(data) - yield self.wb.adr.eq(addr) - yield self.wb.we.eq(we) - yield self.wb.cyc.eq(1) - yield self.wb.stb.eq(1) + def master_set(self, sim: SimulatorContext, addr: int, data: int, we: int): + sim.set(self.wb.dat_w, data) + sim.set(self.wb.adr, addr) + sim.set(self.wb.we, we) + sim.set(self.wb.cyc, 1) + sim.set(self.wb.stb, 1) - def master_release(self, release_cyc=1): - yield self.wb.stb.eq(0) + def master_release(self, sim: SimulatorContext, release_cyc: bool = True): + sim.set(self.wb.stb, 0) if release_cyc: - yield self.wb.cyc.eq(0) - - def master_verify(self, 
exp_data=0):
-        assert (yield self.wb.ack)
-        assert (yield self.wb.dat_r) == exp_data
-
-    def slave_wait(self):
-        while not ((yield self.wb.stb) and (yield self.wb.cyc)):
-            yield Tick()
-
-    def slave_verify(self, exp_addr, exp_data, exp_we, exp_sel=0):
-        assert (yield self.wb.stb) and (yield self.wb.cyc)
-
-        assert (yield self.wb.adr) == exp_addr
-        assert (yield self.wb.we) == exp_we
-        assert (yield self.wb.sel) == exp_sel
+            sim.set(self.wb.cyc, 0)
+
+    async def slave_wait(self, sim: SimulatorContext):
+        *_, adr, we, sel, dat_w = (
+            await sim.tick()
+            .sample(self.wb.adr, self.wb.we, self.wb.sel, self.wb.dat_w)
+            .until(self.wb.stb & self.wb.cyc)
+        )
+        return adr, we, sel, dat_w
+
+    async def slave_wait_and_verify(
+        self, sim: SimulatorContext, exp_addr: int, exp_data: int, exp_we: int, exp_sel: int = 0
+    ):
+        adr, we, sel, dat_w = await self.slave_wait(sim)
+
+        assert adr == exp_addr
+        assert we == exp_we
+        assert sel == exp_sel
         if exp_we:
-            assert (yield self.wb.dat_w) == exp_data
-
-    def slave_respond(self, data, ack=1, err=0, rty=0):
-        assert (yield self.wb.stb) and (yield self.wb.cyc)
-
-        yield self.wb.dat_r.eq(data)
-        yield self.wb.ack.eq(ack)
-        yield self.wb.err.eq(err)
-        yield self.wb.rty.eq(rty)
-        yield Tick()
-        yield self.wb.ack.eq(0)
-        yield self.wb.err.eq(0)
-        yield self.wb.rty.eq(0)
-
-    def wait_ack(self):
-        while not ((yield self.wb.stb) and (yield self.wb.cyc) and (yield self.wb.ack)):
-            yield Tick()
+            assert dat_w == exp_data
+
+    async def slave_tick_and_verify(
+        self, sim: SimulatorContext, exp_addr: int, exp_data: int, exp_we: int, exp_sel: int = 0
+    ):
+        *_, adr, we, sel, dat_w, stb, cyc = await sim.tick().sample(
+            self.wb.adr, self.wb.we, self.wb.sel, self.wb.dat_w, self.wb.stb, self.wb.cyc
+        )
+        assert stb and cyc
+
+        assert adr == exp_addr
+        assert we == exp_we
+        assert sel == exp_sel
+        if exp_we:
+            assert dat_w == exp_data
+
+    async def slave_respond(
+        self,
+        sim: SimulatorContext,
+        data: int,
+        ack: int = 1,
+        err: int = 0,
+        rty: int = 0,
+        sample: Iterable[ValueLike] = (),
+    ):
+        assert sim.get(self.wb.stb) and sim.get(self.wb.cyc)
+
+        sim.set(self.wb.dat_r, data)
+        sim.set(self.wb.ack, ack)
+        sim.set(self.wb.err, err)
+        sim.set(self.wb.rty, rty)
+        ret = await sim.tick().sample(*sample)
+        sim.set(self.wb.ack, 0)
+        sim.set(self.wb.err, 0)
+        sim.set(self.wb.rty, 0)
+        return ret
+
+    async def slave_respond_master_verify(
+        self, sim: SimulatorContext, master: WishboneInterface, data: int, ack: int = 1, err: int = 0, rty: int = 0
+    ):
+        *_, ack, dat_r = await self.slave_respond(sim, data, ack, err, rty, sample=[master.ack, master.dat_r])
+        assert ack and dat_r == data
+
+    async def wait_ack(self, sim: SimulatorContext):
+        await sim.tick().until(self.wb.stb & self.wb.cyc & self.wb.ack)
 
 
 class TestWishboneMaster(TestCaseWithSimulator):
@@ -75,71 +107,63 @@ def elaborate(self, platform):
     def test_manual(self):
         twbm = TestWishboneMaster.WishboneMasterTestModule()
 
-        def process():
+        async def process(sim: TestbenchContext):
            # read request
-            yield from twbm.requestAdapter.call(addr=2, data=0, we=0, sel=1)
+            await twbm.requestAdapter.call(sim, addr=2, data=0, we=0, sel=1)
 
             # read request after delay
-            yield Tick()
-            yield Tick()
-            yield from twbm.requestAdapter.call(addr=1, data=0, we=0, sel=1)
+            await sim.tick()
+            await sim.tick()
+            await twbm.requestAdapter.call(sim, addr=1, data=0, we=0, sel=1)
 
             # write request
-            yield from twbm.requestAdapter.call(addr=3, data=5, we=1, sel=0)
+            await twbm.requestAdapter.call(sim, addr=3, data=5, we=1, sel=0)
 
             # RTY and ERR responses
-            
yield from twbm.requestAdapter.call(addr=2, data=0, we=0, sel=0)
-            resp = yield from twbm.requestAdapter.call_try(addr=0, data=0, we=0, sel=0)
+            await twbm.requestAdapter.call(sim, addr=2, data=0, we=0, sel=0)
+            resp = await twbm.requestAdapter.call_try(sim, addr=0, data=0, we=0, sel=0)
             assert resp is None  # verify cycle restart
 
-        def result_process():
-            resp = yield from twbm.resultAdapter.call()
+        async def result_process(sim: TestbenchContext):
+            resp = await twbm.resultAdapter.call(sim)
             assert resp["data"] == 8
             assert not resp["err"]
 
-            resp = yield from twbm.resultAdapter.call()
+            resp = await twbm.resultAdapter.call(sim)
             assert resp["data"] == 3
             assert not resp["err"]
 
-            resp = yield from twbm.resultAdapter.call()
+            resp = await twbm.resultAdapter.call(sim)
             assert not resp["err"]
 
-            resp = yield from twbm.resultAdapter.call()
+            resp = await twbm.resultAdapter.call(sim)
             assert resp["data"] == 1
             assert resp["err"]
 
-        def slave():
+        async def slave(sim: TestbenchContext):
             wwb = WishboneInterfaceWrapper(twbm.wbm.wb_master)
 
-            yield from wwb.slave_wait()
-            yield from wwb.slave_verify(2, 0, 0, 1)
-            yield from wwb.slave_respond(8)
-            yield Settle()
+            await wwb.slave_wait_and_verify(sim, 2, 0, 0, 1)
+            await wwb.slave_respond(sim, 8)
 
-            yield from wwb.slave_wait()
-            yield from wwb.slave_verify(1, 0, 0, 1)
-            yield from wwb.slave_respond(3)
-            yield Settle()
+            await wwb.slave_wait_and_verify(sim, 1, 0, 0, 1)
+            await wwb.slave_respond(sim, 3)
 
-            yield  # consecutive request
-            yield from wwb.slave_verify(3, 5, 1, 0)
-            yield from wwb.slave_respond(0)
-            yield Tick()
+            await wwb.slave_tick_and_verify(sim, 3, 5, 1, 0)
+            await wwb.slave_respond(sim, 0)
+            await sim.tick()
 
-            yield  # consecutive request
-            yield from wwb.slave_verify(2, 0, 0, 0)
-            yield from wwb.slave_respond(1, ack=0, err=0, rty=1)
-            yield Settle()
-            assert not (yield wwb.wb.stb)
+            await wwb.slave_tick_and_verify(sim, 2, 0, 0, 0)
+            await wwb.slave_respond(sim, 1, ack=0, err=0, rty=1)
+            assert not sim.get(wwb.wb.stb)
 
-            yield from wwb.slave_wait()
-            yield from wwb.slave_verify(2, 0, 0, 0)
-            yield from wwb.slave_respond(1, ack=1, err=1, rty=0)
+            await wwb.slave_wait_and_verify(sim, 2, 0, 0, 0)
+            await wwb.slave_respond(sim, 1, ack=1, err=1, rty=0)
 
         with self.run_simulation(twbm) as sim:
-            sim.add_process(process)
-            sim.add_process(result_process)
-            sim.add_process(slave)
+            sim.add_testbench(process)
+            sim.add_testbench(result_process)
+            sim.add_testbench(slave)
 
 
 class TestWishboneMuxer(TestCaseWithSimulator):
@@ -149,97 +173,80 @@ def test_manual(self):
         slaves = [WishboneInterfaceWrapper(slave) for slave in mux.slaves]
         wb_master = WishboneInterfaceWrapper(mux.master_wb)
 
-        def process():
+        async def process(sim: TestbenchContext):
             # check full communication
-            yield from wb_master.master_set(2, 0, 1)
-            yield mux.sselTGA.eq(0b0001)
-            yield Tick()
-            yield from slaves[0].slave_verify(2, 0, 1)
-            assert not (yield slaves[1].wb.stb)
-            yield from slaves[0].slave_respond(4)
-            yield from wb_master.master_verify(4)
-            yield from wb_master.master_release(release_cyc=0)
-            yield Tick()
+            wb_master.master_set(sim, 2, 0, 1)
+            sim.set(mux.sselTGA, 0b0001)
+            await slaves[0].slave_tick_and_verify(sim, 2, 0, 1)
+            assert not sim.get(slaves[1].wb.stb)
+            await slaves[0].slave_respond_master_verify(sim, wb_master.wb, 4)
+            wb_master.master_release(sim, release_cyc=False)
+            await sim.tick()
 
             # select without releasing cyc (only on stb)
-            yield from wb_master.master_set(3, 0, 0)
-            yield mux.sselTGA.eq(0b0010)
-            yield Tick()
-            assert not (yield slaves[0].wb.stb)
-            yield from 
slaves[1].slave_verify(3, 0, 0) - yield from slaves[1].slave_respond(5) - yield from wb_master.master_verify(5) - yield from wb_master.master_release() - yield Tick() + wb_master.master_set(sim, 3, 0, 0) + sim.set(mux.sselTGA, 0b0010) + await slaves[1].slave_tick_and_verify(sim, 3, 0, 0) + assert not sim.get(slaves[0].wb.stb) + await slaves[1].slave_respond_master_verify(sim, wb_master.wb, 5) + wb_master.master_release(sim) + await sim.tick() # normal selection - yield from wb_master.master_set(6, 0, 0) - yield mux.sselTGA.eq(0b1000) - yield Tick() - yield from slaves[3].slave_verify(6, 0, 0) - yield from slaves[3].slave_respond(1) - yield from wb_master.master_verify(1) + wb_master.master_set(sim, 6, 0, 0) + sim.set(mux.sselTGA, 0b1000) + await slaves[3].slave_tick_and_verify(sim, 6, 0, 0) + await slaves[3].slave_respond_master_verify(sim, wb_master.wb, 1) with self.run_simulation(mux) as sim: - sim.add_process(process) + sim.add_testbench(process) -class TestWishboneAribiter(TestCaseWithSimulator): +class TestWishboneArbiter(TestCaseWithSimulator): def test_manual(self): arb = WishboneArbiter(WishboneParameters(), 2) slave = WishboneInterfaceWrapper(arb.slave_wb) masters = [WishboneInterfaceWrapper(master) for master in arb.masters] - def process(): - yield from masters[0].master_set(2, 3, 1) - yield from slave.slave_wait() - yield from slave.slave_verify(2, 3, 1) - yield from masters[1].master_set(1, 4, 1) - yield from slave.slave_respond(0) - - yield from masters[0].master_verify() - assert not (yield masters[1].wb.ack) - yield from masters[0].master_release() - yield Tick() + async def process(sim: TestbenchContext): + masters[0].master_set(sim, 2, 3, 1) + await slave.slave_wait_and_verify(sim, 2, 3, 1) + masters[1].master_set(sim, 1, 4, 1) + await slave.slave_respond_master_verify(sim, masters[0].wb, 0) + assert not sim.get(masters[1].wb.ack) + masters[0].master_release(sim) + await sim.tick() # check if bus is granted to next master if previous ends cycle - yield from slave.slave_wait() - yield from slave.slave_verify(1, 4, 1) - yield from slave.slave_respond(0) - yield from masters[1].master_verify() - assert not (yield masters[0].wb.ack) - yield from masters[1].master_release() - yield Tick() + await slave.slave_wait_and_verify(sim, 1, 4, 1) + await slave.slave_respond_master_verify(sim, masters[1].wb, 0) + assert not sim.get(masters[0].wb.ack) + masters[1].master_release(sim) + await sim.tick() # check round robin behaviour (2 masters requesting *2) - yield from masters[0].master_set(1, 0, 0) - yield from masters[1].master_set(2, 0, 0) - yield from slave.slave_wait() - yield from slave.slave_verify(1, 0, 0) - yield from slave.slave_respond(3) - yield from masters[0].master_verify(3) - yield from masters[0].master_release() - yield from masters[1].master_release() - yield Tick() - assert not (yield slave.wb.cyc) - - yield from masters[0].master_set(1, 0, 0) - yield from masters[1].master_set(2, 0, 0) - yield from slave.slave_wait() - yield from slave.slave_verify(2, 0, 0) - yield from slave.slave_respond(0) - yield from masters[1].master_verify() + masters[0].master_set(sim, 1, 0, 0) + masters[1].master_set(sim, 2, 0, 0) + await slave.slave_wait_and_verify(sim, 1, 0, 0) + await slave.slave_respond_master_verify(sim, masters[0].wb, 3) + masters[0].master_release(sim) + masters[1].master_release(sim) + await sim.tick() + assert not sim.get(slave.wb.cyc) + + masters[0].master_set(sim, 1, 0, 0) + masters[1].master_set(sim, 2, 0, 0) + await slave.slave_wait_and_verify(sim, 2, 0, 0) + 
await slave.slave_respond_master_verify(sim, masters[1].wb, 0) # check if releasing stb keeps grant - yield from masters[1].master_release(release_cyc=0) - yield Tick() - yield from masters[1].master_set(3, 0, 0) - yield from slave.slave_wait() - yield from slave.slave_verify(3, 0, 0) - yield from slave.slave_respond(0) - yield from masters[1].master_verify() + masters[1].master_release(sim, release_cyc=False) + await sim.tick() + masters[1].master_set(sim, 3, 0, 0) + await slave.slave_wait_and_verify(sim, 3, 0, 0) + await slave.slave_respond_master_verify(sim, masters[1].wb, 0) with self.run_simulation(arb) as sim: - sim.add_process(process) + sim.add_testbench(process) class TestPipelinedWishboneMaster(TestCaseWithSimulator): @@ -254,7 +261,7 @@ def test_randomized(self): wb_params = WishboneParameters() pwbm = SimpleTestCircuit(PipelinedWishboneMaster((wb_params))) - def request_process(): + async def request_process(sim: TestbenchContext): for _ in range(requests): request = { "addr": random.randint(0, 2**wb_params.addr_width - 1), @@ -263,49 +270,46 @@ def request_process(): "sel": random.randint(0, 2**wb_params.granularity - 1), } req_queue.appendleft(request) - yield from pwbm.request.call(request) + await pwbm.request.call(sim, request) - def verify_process(): + async def verify_process(sim: TestbenchContext): for _ in range(requests): - while random.random() < 0.8: - yield Tick() + await self.random_wait_geom(sim, 0.8) - result = yield from pwbm.result.call() + result = await pwbm.result.call(sim) cres = res_queue.pop() assert result["data"] == cres assert not result["err"] - def slave_process(): - yield Passive() - + async def slave_process(sim: TestbenchContext): wbw = pwbm._dut.wb - while True: - if (yield wbw.cyc) and (yield wbw.stb): - assert not (yield wbw.stall) + async for *_, cyc, stb, stall, adr, dat_w, we, sel in sim.tick().sample( + wbw.cyc, wbw.stb, wbw.stall, wbw.adr, wbw.dat_w, wbw.we, wbw.sel + ): + if cyc and stb: + assert not stall assert req_queue c_req = req_queue.pop() - assert (yield wbw.adr) == c_req["addr"] - assert (yield wbw.dat_w) == c_req["data"] - assert (yield wbw.we) == c_req["we"] - assert (yield wbw.sel) == c_req["sel"] + assert adr == c_req["addr"] + assert dat_w == c_req["data"] + assert we == c_req["we"] + assert sel == c_req["sel"] - slave_queue.appendleft((yield wbw.dat_w)) - res_queue.appendleft((yield wbw.dat_w)) + slave_queue.appendleft(dat_w) + res_queue.appendleft(dat_w) if slave_queue and random.random() < 0.4: - yield wbw.ack.eq(1) - yield wbw.dat_r.eq(slave_queue.pop()) + sim.set(wbw.ack, 1) + sim.set(wbw.dat_r, slave_queue.pop()) else: - yield wbw.ack.eq(0) - - yield wbw.stall.eq(random.random() < 0.3) + sim.set(wbw.ack, 0) - yield Tick() + sim.set(wbw.stall, random.random() < 0.3) with self.run_simulation(pwbm) as sim: - sim.add_process(request_process) - sim.add_process(verify_process) - sim.add_process(slave_process) + sim.add_testbench(request_process) + sim.add_testbench(verify_process) + sim.add_testbench(slave_process, background=True) class WishboneMemorySlaveCircuit(Elaboratable): @@ -341,11 +345,10 @@ def setup_method(self): def test_randomized(self): req_queue = deque() - wr_queue = deque() mem_state = [0] * self.memsize - def request_process(): + async def request_process(sim: TestbenchContext): for _ in range(self.iters): req = { "addr": random.randint(0, self.memsize - 1), @@ -354,41 +357,27 @@ def request_process(): "sel": random.randint(0, 2**self.sel_width - 1), } req_queue.appendleft(req) - wr_queue.appendleft(req) - 
while random.random() < 0.2: - yield Tick() - yield from self.m.request.call(req) + await self.random_wait_geom(sim, 0.2) + await self.m.request.call(sim, req) - def result_process(): + async def result_process(sim: TestbenchContext): for _ in range(self.iters): - while random.random() < 0.2: - yield Tick() - res = yield from self.m.result.call() + await self.random_wait_geom(sim, 0.2) + res = await self.m.result.call(sim) req = req_queue.pop() if not req["we"]: assert res["data"] == mem_state[req["addr"]] - - def write_process(): - wwb = WishboneInterfaceWrapper(self.m.mem_master.wb_master) - for _ in range(self.iters): - yield from wwb.wait_ack() - req = wr_queue.pop() - - if req["we"]: + else: for i in range(self.sel_width): if req["sel"] & (1 << i): granularity_mask = (2**self.wb_params.granularity - 1) << (i * self.wb_params.granularity) mem_state[req["addr"]] &= ~granularity_mask mem_state[req["addr"]] |= req["data"] & granularity_mask - - yield Tick() - - if req["we"]: - assert (yield Value.cast(self.m.mem_slave.mem.data[req["addr"]])) == mem_state[req["addr"]] + val = sim.get(Value.cast(self.m.mem_slave.mem.data[req["addr"]])) + assert val == mem_state[req["addr"]] with self.run_simulation(self.m, max_cycles=3000) as sim: - sim.add_process(request_process) - sim.add_process(result_process) - sim.add_process(write_process) + sim.add_testbench(request_process) + sim.add_testbench(result_process) diff --git a/test/priv/traps/test_exception.py b/test/priv/traps/test_exception.py index 22ebb8b5e..824d892ba 100644 --- a/test/priv/traps/test_exception.py +++ b/test/priv/traps/test_exception.py @@ -38,16 +38,16 @@ def test_randomized(self): self.dut = SimpleTestCircuit( ExceptionInformationRegister( self.gen_params, self.rob_idx_mock.adapter.iface, self.fetch_stall_mock.adapter.iface - ) + ), ) m = ModuleConnector(self.dut, rob_idx_mock=self.rob_idx_mock, fetch_stall_mock=self.fetch_stall_mock) self.rob_id = 0 - def process_test(): + async def process_test(sim: TestbenchContext): saved_entry = None - yield from self.fetch_stall_mock.enable() + self.fetch_stall_mock.enable(sim) for _ in range(self.cycles): self.rob_id = random.randint(0, self.rob_max) @@ -61,12 +61,13 @@ def process_test(): report_arg = {"cause": cause, "rob_id": report_rob, "pc": report_pc, "mtval": report_mtval} expected = report_arg if self.should_update(report_arg, saved_entry, self.rob_id) else saved_entry - yield from self.dut.report.call(report_arg) - yield # additional FIFO delay + await self.dut.report.call(sim, report_arg) + # additional FIFO delay + *_, fetch_stall_mock_done = await self.fetch_stall_mock.sample_outputs_done(sim) - assert (yield from self.fetch_stall_mock.done()) + assert fetch_stall_mock_done - new_state = yield from self.dut.get.call() + new_state = data_const_to_dict(await self.dut.get.call(sim)) assert new_state == expected | {"valid": 1} # type: ignore @@ -77,4 +78,4 @@ def process_rob_idx_mock(): return {"start": self.rob_id, "end": 0} with self.run_simulation(m) as sim: - sim.add_process(process_test) + sim.add_testbench(process_test) diff --git a/test/regression/pysim.py b/test/regression/pysim.py index a21b293fe..ee8aa5990 100644 --- a/test/regression/pysim.py +++ b/test/regression/pysim.py @@ -2,22 +2,22 @@ import os import logging -from amaranth.sim import Passive, Settle, Tick from amaranth.utils import exact_log2 from amaranth import * from transactron.core.keys import TransactionManagerKey +from transactron.profiler import Profile +from transactron.testing.tick_count import 
make_tick_count_process from .memory import * from .common import SimulationBackend, SimulationExecutionResult from transactron.testing import ( PysimSimulator, - TestGen, profiler_process, - Profile, make_logging_process, parse_logging_level, + TestbenchContext, ) from transactron.utils.dependencies import DependencyContext, DependencyManager from transactron.lib.metrics import HardwareMetricsManager @@ -43,22 +43,20 @@ def __init__(self, traces_file: Optional[str] = None): def _wishbone_slave( self, mem_model: CoreMemoryModel, wb_ctrl: WishboneInterfaceWrapper, is_instr_bus: bool, delay: int = 0 ): - def f(): - yield Passive() - + async def f(sim: TestbenchContext): while True: - yield from wb_ctrl.slave_wait() + await wb_ctrl.slave_wait(sim) word_width_bytes = self.gp.isa.xlen // 8 # Wishbone is addressing words, so we need to shift it a bit to get the real address. - addr = (yield wb_ctrl.wb.adr) << exact_log2(word_width_bytes) - sel = yield wb_ctrl.wb.sel - dat_w = yield wb_ctrl.wb.dat_w + addr = sim.get(wb_ctrl.wb.adr) << exact_log2(word_width_bytes) + sel = sim.get(wb_ctrl.wb.sel) + dat_w = sim.get(wb_ctrl.wb.dat_w) resp_data = 0 - if (yield wb_ctrl.wb.we): + if sim.get(wb_ctrl.wb.we): resp = mem_model.write( WriteRequest(addr=addr, data=dat_w, byte_count=word_width_bytes, byte_sel=sel) ) @@ -83,21 +81,19 @@ def f(): rty = 1 for _ in range(delay): - yield Tick() - - yield from wb_ctrl.slave_respond(resp_data, ack=ack, err=err, rty=rty) + await sim.tick() - yield Settle() + await wb_ctrl.slave_respond(sim, resp_data, ack=ack, err=err, rty=rty) return f - def _waiter(self, on_finish: Callable[[], TestGen[None]]): - def f(): + def _waiter(self, on_finish: Callable[[TestbenchContext], None]): + async def f(sim: TestbenchContext): while self.running: self.cycle_cnt += 1 - yield Tick() + await sim.tick() - yield from on_finish() + on_finish(sim) return f @@ -141,13 +137,14 @@ async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> S self.cycle_cnt = 0 sim = PysimSimulator(core, max_cycles=timeout_cycles, traces_file=self.traces_file) - sim.add_process(self._wishbone_slave(mem_model, wb_instr_ctrl, is_instr_bus=True)) - sim.add_process(self._wishbone_slave(mem_model, wb_data_ctrl, is_instr_bus=False)) + sim.add_testbench(self._wishbone_slave(mem_model, wb_instr_ctrl, is_instr_bus=True), background=True) + sim.add_testbench(self._wishbone_slave(mem_model, wb_data_ctrl, is_instr_bus=False), background=True) def on_error(): raise RuntimeError("Simulation finished due to an error") sim.add_process(make_logging_process(self.log_level, self.log_filter, on_error)) + sim.add_process(make_tick_count_process()) # This enables logging in benchmarks. TODO: after unifying regression testing, remove. 
logging.basicConfig() @@ -161,17 +158,17 @@ def on_error(): metric_values: dict[str, dict[str, int]] = {} - def on_sim_finish(): + def on_sim_finish(sim: TestbenchContext): # Collect metric values before we finish the simulation for metric_name, metric in self.metrics_manager.get_metrics().items(): metric = self.metrics_manager.get_metrics()[metric_name] metric_values[metric_name] = {} for reg_name in metric.regs: - metric_values[metric_name][reg_name] = yield self.metrics_manager.get_register_value( - metric_name, reg_name + metric_values[metric_name][reg_name] = sim.get( + self.metrics_manager.get_register_value(metric_name, reg_name) ) - sim.add_process(self._waiter(on_finish=on_sim_finish)) + sim.add_testbench(self._waiter(on_finish=on_sim_finish)) success = sim.run() self.pretty_dump_metrics(metric_values) diff --git a/test/scheduler/test_rs_selection.py b/test/scheduler/test_rs_selection.py index d00ac64f3..9a7e7d48b 100644 --- a/test/scheduler/test_rs_selection.py +++ b/test/scheduler/test_rs_selection.py @@ -2,7 +2,6 @@ import random from amaranth import * -from amaranth.sim import Settle, Passive from coreblocks.params import GenParams from coreblocks.interface.layouts import RSLayouts, SchedulerLayouts @@ -11,7 +10,9 @@ from coreblocks.params.configurations import test_core_config from coreblocks.scheduler.scheduler import RSSelection from transactron.lib import FIFO, Adapter, AdapterTrans -from transactron.testing import TestCaseWithSimulator, TestbenchIO +from transactron.testing import TestCaseWithSimulator, TestbenchIO, TestbenchContext +from transactron.testing.functions import data_const_to_dict +from transactron.testing.method_mock import MethodMock, def_method_mock _rs1_optypes = {OpType.ARITHMETIC, OpType.COMPARE} _rs2_optypes = {OpType.LOGIC, OpType.COMPARE} @@ -52,12 +53,12 @@ class TestRSSelect(TestCaseWithSimulator): def setup_method(self): self.gen_params = GenParams(test_core_config) self.m = RSSelector(self.gen_params) - self.expected_out = deque() - self.instr_in = deque() + self.expected_out: deque[dict] = deque() + self.instr_in: deque[dict] = deque() random.seed(1789) def create_instr_input_process(self, instr_count: int, optypes: set[OpType], random_wait: int = 0): - def process(): + async def process(sim: TestbenchContext): for i in range(instr_count): rp_dst = random.randrange(self.gen_params.phys_regs_bits) rp_s1 = random.randrange(self.gen_params.phys_regs_bits) @@ -91,41 +92,36 @@ def process(): } self.instr_in.append(instr) - yield from self.m.instr_in.call(instr) - yield from self.random_wait(random_wait) + await self.m.instr_in.call(sim, instr) + await self.random_wait(sim, random_wait) return process - def create_rs_alloc_process(self, io: TestbenchIO, rs_id: int, rs_optypes: set[OpType], random_wait: int = 0): - def mock(): + def create_rs_alloc_process(self, io: TestbenchIO, rs_id: int, rs_optypes: set[OpType], enable_prob: float = 1): + @def_method_mock(lambda: io, enable=lambda: random.random() <= enable_prob) + def process(): random_entry = random.randrange(self.gen_params.max_rs_entries) - expected = self.instr_in.popleft() - assert expected["exec_fn"]["op_type"] in rs_optypes - expected["rs_entry_id"] = random_entry - expected["rs_selected"] = rs_id - self.expected_out.append(expected) - return {"rs_entry_id": random_entry} + @MethodMock.effect + def eff(): + expected = self.instr_in.popleft() + assert expected["exec_fn"]["op_type"] in rs_optypes + expected["rs_entry_id"] = random_entry + expected["rs_selected"] = rs_id + 
self.expected_out.append(expected) - def process(): - yield Passive() - while True: - yield from io.enable() - yield from io.method_handle(mock) - yield from io.disable() - yield from self.random_wait(random_wait) + return {"rs_entry_id": random_entry} - return process + return process() def create_output_process(self, instr_count: int, random_wait: int = 0): - def process(): + async def process(sim: TestbenchContext): for _ in range(instr_count): - result = yield from self.m.instr_out.call() + result = await self.m.instr_out.call(sim) outputs = self.expected_out.popleft() - yield from self.random_wait(random_wait) - yield Settle() - assert result == outputs + await self.random_wait(sim, random_wait) + assert data_const_to_dict(result) == outputs return process @@ -135,10 +131,10 @@ def test_base_functionality(self): """ with self.run_simulation(self.m, max_cycles=1500) as sim: - sim.add_process(self.create_instr_input_process(100, _rs1_optypes.union(_rs2_optypes))) - sim.add_process(self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) - sim.add_process(self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) - sim.add_process(self.create_output_process(100)) + sim.add_testbench(self.create_instr_input_process(100, _rs1_optypes.union(_rs2_optypes))) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) + sim.add_testbench(self.create_output_process(100)) def test_only_rs1(self): """ @@ -147,9 +143,9 @@ def test_only_rs1(self): """ with self.run_simulation(self.m, max_cycles=1500) as sim: - sim.add_process(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) - sim.add_process(self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) - sim.add_process(self.create_output_process(100)) + sim.add_testbench(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes)) + sim.add_testbench(self.create_output_process(100)) def test_only_rs2(self): """ @@ -158,9 +154,9 @@ def test_only_rs2(self): """ with self.run_simulation(self.m, max_cycles=1500) as sim: - sim.add_process(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) - sim.add_process(self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) - sim.add_process(self.create_output_process(100)) + sim.add_testbench(self.create_instr_input_process(100, _rs1_optypes.intersection(_rs2_optypes))) + self.add_mock(sim, self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes)) + sim.add_testbench(self.create_output_process(100)) def test_delays(self): """ @@ -169,11 +165,11 @@ def test_delays(self): """ with self.run_simulation(self.m, max_cycles=5000) as sim: - sim.add_process(self.create_instr_input_process(300, _rs1_optypes.union(_rs2_optypes), random_wait=4)) - sim.add_process( - self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes, random_wait=12) + sim.add_testbench(self.create_instr_input_process(300, _rs1_optypes.union(_rs2_optypes), random_wait=4)) + self.add_mock( + sim, self.create_rs_alloc_process(self.m.rs1_alloc, rs_id=0, rs_optypes=_rs1_optypes, enable_prob=0.1) ) - sim.add_process( - self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes, 
random_wait=12) + self.add_mock( + sim, self.create_rs_alloc_process(self.m.rs2_alloc, rs_id=1, rs_optypes=_rs2_optypes, enable_prob=0.1) ) - sim.add_process(self.create_output_process(300, random_wait=12)) + sim.add_testbench(self.create_output_process(300, random_wait=12)) diff --git a/test/scheduler/test_scheduler.py b/test/scheduler/test_scheduler.py index 293ce201f..9c54c043b 100644 --- a/test/scheduler/test_scheduler.py +++ b/test/scheduler/test_scheduler.py @@ -3,15 +3,15 @@ from collections import namedtuple, deque from typing import Callable, Optional, Iterable from amaranth import * -from amaranth.lib.data import View -from amaranth.sim import Settle, Tick from parameterized import parameterized_class from coreblocks.interface.keys import CoreStateKey -from coreblocks.interface.layouts import ROBLayouts, RetirementLayouts +from coreblocks.interface.layouts import RetirementLayouts from coreblocks.func_blocks.fu.common.rs_func_block import RSBlockComponent from transactron.core import Method from transactron.lib import FIFO, AdapterTrans, Adapter +from transactron.testing.functions import MethodData, data_const_to_dict +from transactron.testing.method_mock import MethodMock from transactron.utils.dependencies import DependencyContext from coreblocks.scheduler.scheduler import Scheduler from coreblocks.core_structs.rf import RegisterFile @@ -22,7 +22,7 @@ from coreblocks.params.configurations import test_core_config from coreblocks.core_structs.rob import ReorderBuffer from coreblocks.func_blocks.interface.func_protocols import FuncBlock -from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestGen, TestbenchIO, def_method_mock +from transactron.testing import TestCaseWithSimulator, TestbenchIO, def_method_mock, TestbenchContext class SchedulerTestCircuit(Elaboratable): @@ -157,7 +157,7 @@ def free_phys_reg(self, reg_id): self.free_regs_queue.append({"reg_id": reg_id}) self.expected_phys_reg_queue.append(reg_id) - def queue_gather(self, queues: Iterable[deque]): + async def queue_gather(self, sim: TestbenchContext, queues: Iterable[deque]): # Iterate over all 'queues' and take one element from each, gathering # all key-value pairs into 'item'. item = {} @@ -166,6 +166,7 @@ def queue_gather(self, queues: Iterable[deque]): # retry until we get an element while partial_item is None: # get element from one queue + await sim.delay(1e-9) if q: partial_item = q.popleft() # None signals to end the process @@ -173,7 +174,7 @@ def queue_gather(self, queues: Iterable[deque]): return None else: # if no element available, wait and retry on the next clock cycle - yield Tick() + await sim.tick() # merge queue element with all previous ones (dict merge) item = item | partial_item @@ -185,7 +186,7 @@ def make_queue_process( io: TestbenchIO, input_queues: Optional[Iterable[deque]] = None, output_queues: Optional[Iterable[deque]] = None, - check: Optional[Callable[[RecordIntDict, RecordIntDict], TestGen[None]]] = None, + check: Optional[Callable[[TestbenchContext, MethodData, dict], None]] = None, always_enable: bool = False, ): """Create queue gather-and-test process @@ -235,31 +236,30 @@ def make_queue_process( If neither `input_queues` nor `output_queues` are supplied. 
""" - def queue_process(): + async def queue_process(sim: TestbenchContext): if always_enable: - yield from io.enable() + io.enable(sim) while True: inputs = {} outputs = {} # gather items from both queues if input_queues is not None: - inputs = yield from self.queue_gather(input_queues) + inputs = await self.queue_gather(sim, input_queues) if output_queues is not None: - outputs = yield from self.queue_gather(output_queues) + outputs = await self.queue_gather(sim, output_queues) # Check if queues signalled to end the process if inputs is None or outputs is None: return - result = yield from io.call(inputs) + result = await io.call(sim, inputs) if always_enable: - yield from io.enable() + io.enable(sim) # this could possibly be extended to automatically compare 'results' and # 'outputs' if check is None but that needs some dict deepcompare if check is not None: - yield Settle() - yield from check(result, outputs) + check(sim, result, outputs) if output_queues is None and input_queues is None: raise ValueError("Either output_queues or input_queues must be supplied") @@ -267,44 +267,39 @@ def queue_process(): return queue_process def make_output_process(self, io: TestbenchIO, output_queues: Iterable[deque]): - def check(got, expected): - rl_dst = yield View( - self.gen_params.get(ROBLayouts).data_layout, - C( - (yield Value.cast(self.m.rob.data.data[got["rs_data"]["rob_id"]])), - self.gen_params.get(ROBLayouts).data_layout.size, - ), - ).rl_dst + def check(sim: TestbenchContext, got: MethodData, expected: dict): + # TODO: better stubs for Memory? + rl_dst = sim.get(self.m.rob.data.data[got.rs_data.rob_id].rl_dst) # type: ignore s1 = self.rf_state[expected["rp_s1"]] s2 = self.rf_state[expected["rp_s2"]] # if source operand register ids are 0 then we already have values - assert got["rs_data"]["rp_s1"] == (expected["rp_s1"] if not s1.valid else 0) - assert got["rs_data"]["rp_s2"] == (expected["rp_s2"] if not s2.valid else 0) - assert got["rs_data"]["rp_dst"] == expected["rp_dst"] - assert got["rs_data"]["exec_fn"] == expected["exec_fn"] - assert got["rs_entry_id"] == expected["rs_entry_id"] - assert got["rs_data"]["s1_val"] == (s1.value if s1.valid else 0) - assert got["rs_data"]["s2_val"] == (s2.value if s2.valid else 0) + assert got.rs_data.rp_s1 == (expected["rp_s1"] if not s1.valid else 0) + assert got.rs_data.rp_s2 == (expected["rp_s2"] if not s2.valid else 0) + assert got.rs_data.rp_dst == expected["rp_dst"] + assert data_const_to_dict(got.rs_data.exec_fn) == expected["exec_fn"] + assert got.rs_entry_id == expected["rs_entry_id"] + assert got.rs_data.s1_val == (s1.value if s1.valid else 0) + assert got.rs_data.s2_val == (s2.value if s2.valid else 0) assert rl_dst == expected["rl_dst"] # recycle physical register number - if got["rs_data"]["rp_dst"] != 0: - self.free_phys_reg(got["rs_data"]["rp_dst"]) + if got.rs_data.rp_dst != 0: + self.free_phys_reg(got.rs_data.rp_dst) # recycle ROB entry - self.free_ROB_entries_queue.append({"rob_id": got["rs_data"]["rob_id"]}) + self.free_ROB_entries_queue.append({"rob_id": got.rs_data.rob_id}) return self.make_queue_process(io=io, output_queues=output_queues, check=check, always_enable=True) def test_randomized(self): - def instr_input_process(): - yield from self.m.rob_retire.enable() + async def instr_input_process(sim: TestbenchContext): + self.m.rob_retire.enable(sim) # set up RF to reflect our static rf_state reference lookup table for i in range(2**self.gen_params.phys_regs_bits - 1): - yield from self.m.rf_write.call(reg_id=i, 
reg_val=self.rf_state[i].value) + await self.m.rf_write.call(sim, reg_id=i, reg_val=self.rf_state[i].value) if not self.rf_state[i].valid: - yield from self.m.rf_free.call(reg_id=i) + await self.m.rf_free.call(sim, reg_id=i) op_types_set = set() for rs in self.optype_sets: @@ -338,7 +333,8 @@ def instr_input_process(): ) self.current_RAT[rl_dst] = rp_dst - yield from self.m.instr_inp.call( + await self.m.instr_inp.call( + sim, { "exec_fn": { "op_type": op_type, @@ -351,7 +347,7 @@ def instr_input_process(): "rl_dst": rl_dst, }, "imm": immediate, - } + }, ) # Terminate other processes self.expected_rename_queue.append(None) @@ -362,19 +358,22 @@ def rs_alloc_process(io: TestbenchIO, rs_id: int): @def_method_mock(lambda: io) def process(): random_entry = random.randrange(self.gen_params.max_rs_entries) - expected = self.expected_rename_queue.popleft() - expected["rs_entry_id"] = random_entry - self.expected_rs_entry_queue[rs_id].append(expected) - # if last instruction was allocated stop simulation - self.allocated_instr_count += 1 - if self.allocated_instr_count == self.instr_count: - for i in range(self.rs_count): - self.expected_rs_entry_queue[i].append(None) + @MethodMock.effect + def eff(): + expected = self.expected_rename_queue.popleft() + expected["rs_entry_id"] = random_entry + self.expected_rs_entry_queue[rs_id].append(expected) + + # if last instruction was allocated stop simulation + self.allocated_instr_count += 1 + if self.allocated_instr_count == self.instr_count: + for i in range(self.rs_count): + self.expected_rs_entry_queue[i].append(None) return {"rs_entry_id": random_entry} - return process + return process() @def_method_mock(lambda: self.m.core_state) def core_state_mock(): @@ -383,10 +382,10 @@ def core_state_mock(): with self.run_simulation(self.m, max_cycles=1500) as sim: for i in range(self.rs_count): - sim.add_process( + sim.add_testbench( self.make_output_process(io=self.m.rs_insert[i], output_queues=[self.expected_rs_entry_queue[i]]) ) - sim.add_process(rs_alloc_process(self.m.rs_alloc[i], i)) - sim.add_process(self.make_queue_process(io=self.m.rob_done, input_queues=[self.free_ROB_entries_queue])) - sim.add_process(self.make_queue_process(io=self.m.free_rf_inp, input_queues=[self.free_regs_queue])) - sim.add_process(instr_input_process) + self.add_mock(sim, rs_alloc_process(self.m.rs_alloc[i], i)) + sim.add_testbench(self.make_queue_process(io=self.m.rob_done, input_queues=[self.free_ROB_entries_queue])) + sim.add_testbench(self.make_queue_process(io=self.m.free_rf_inp, input_queues=[self.free_regs_queue])) + sim.add_testbench(instr_input_process) diff --git a/test/scheduler/test_wakeup_select.py b/test/scheduler/test_wakeup_select.py index b51af3cd3..cd34de905 100644 --- a/test/scheduler/test_wakeup_select.py +++ b/test/scheduler/test_wakeup_select.py @@ -1,7 +1,6 @@ from typing import Optional, cast from amaranth import * from amaranth.lib.data import StructLayout -from amaranth.sim import Settle, Tick from collections import deque from enum import Enum @@ -16,7 +15,8 @@ from transactron.lib import Adapter from coreblocks.scheduler.wakeup_select import * -from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestbenchIO +from transactron.testing import RecordIntDict, TestCaseWithSimulator, TestbenchIO, TestbenchContext +from transactron.testing.functions import data_const_to_dict class WakeupTestCircuit(Elaboratable): @@ -76,46 +76,38 @@ def maybe_insert(self, rs: list[Optional[RecordIntDict]]): empty_idx -= 1 return 0 - def process(self): + 
async def process(self, sim: TestbenchContext): inserted_count = 0 issued_count = 0 rs: list[Optional[RecordIntDict]] = [None for _ in range(self.m.gen_params.max_rs_entries)] - yield from self.m.take_row_mock.enable() - yield from self.m.issue_mock.enable() - yield Settle() + self.m.take_row_mock.enable(sim) + self.m.issue_mock.enable(sim) for _ in range(self.cycles): inserted_count += self.maybe_insert(rs) - ready = Cat(entry is not None for entry in rs) + ready = Const.cast(Cat(entry is not None for entry in rs)) - yield from self.m.ready_mock.call_init(ready_list=ready) - if any(entry is not None for entry in rs): - yield from self.m.ready_mock.enable() - else: - yield from self.m.ready_mock.disable() + self.m.ready_mock.call_init(sim, ready_list=ready) + self.m.ready_mock.set_enable(sim, any(entry is not None for entry in rs)) - yield Settle() - - take_position = yield from self.m.take_row_mock.call_result() + take_position = self.m.take_row_mock.get_call_result(sim) if take_position is not None: take_position = cast(int, take_position["rs_entry_id"]) entry = rs[take_position] assert entry is not None self.taken.append(entry) - yield from self.m.take_row_mock.call_init(entry) + self.m.take_row_mock.call_init(sim, entry) rs[take_position] = None - yield Settle() - - issued = yield from self.m.issue_mock.call_result() + issued = self.m.issue_mock.get_call_result(sim) if issued is not None: - assert issued == self.taken.popleft() + assert data_const_to_dict(issued) == self.taken.popleft() issued_count += 1 - yield Tick() + await sim.tick() assert inserted_count != 0 assert inserted_count == issued_count def test(self): with self.run_simulation(self.m) as sim: - sim.add_process(self.process) + sim.add_testbench(self.process) diff --git a/test/test_core.py b/test/test_core.py index 6825ffd1b..237360f53 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1,23 +1,29 @@ +from collections.abc import Callable +from typing import Any from amaranth import * from amaranth.lib.wiring import connect -from amaranth.sim import Passive, Tick +from amaranth_types import ValueLike +from transactron.testing.tick_count import TicksKey from transactron.utils import align_to_power_of_two -from transactron.testing import TestCaseWithSimulator +from transactron.testing import TestCaseWithSimulator, ProcessContext, TestbenchContext from coreblocks.arch.isa_consts import PrivilegeLevel from coreblocks.core import Core from coreblocks.params import GenParams from coreblocks.params.instr import * -from coreblocks.params.configurations import CoreConfiguration, basic_core_config, full_core_config +from coreblocks.params.configurations import * from coreblocks.peripherals.wishbone import WishboneMemorySlave +from coreblocks.priv.traps.interrupt_controller import ISA_RESERVED_INTERRUPTS import random import subprocess import tempfile from parameterized import parameterized_class +from transactron.utils.dependencies import DependencyContext + class CoreTestElaboratable(Elaboratable): def __init__(self, gen_params: GenParams, instr_mem: list[int] = [0], data_mem: list[int] = []): @@ -39,10 +45,12 @@ def elaborate(self, platform): self.core = Core(gen_params=self.gen_params) - self.interrupt_level = Signal() - self.interrupt_edge = Signal() - - m.d.comb += self.core.interrupt_controller.custom_report.eq(Cat(self.interrupt_edge, self.interrupt_level)) + if self.gen_params.interrupt_custom_count == 2: + self.interrupt_level = Signal() + self.interrupt_edge = Signal() + m.d.comb += self.core.interrupts.eq( + 
Cat(self.interrupt_edge, self.interrupt_level) << ISA_RESERVED_INTERRUPTS + ) m.submodules.wb_mem_slave = self.wb_mem_slave m.submodules.wb_mem_slave_data = self.wb_mem_slave_data @@ -58,11 +66,11 @@ class TestCoreBase(TestCaseWithSimulator): gen_params: GenParams m: CoreTestElaboratable - def get_phys_reg_rrat(self, reg_id): - return (yield self.m.core.RRAT.entries[reg_id]) + def get_phys_reg_rrat(self, sim: TestbenchContext, reg_id): + return sim.get(self.m.core.RRAT.entries[reg_id]) - def get_arch_reg_val(self, reg_id): - return (yield self.m.core.RF.entries[(yield from self.get_phys_reg_rrat(reg_id))].reg_val) + def get_arch_reg_val(self, sim: TestbenchContext, reg_id): + return sim.get(self.m.core.RF.entries[(self.get_phys_reg_rrat(sim, reg_id))].reg_val) class TestCoreAsmSourceBase(TestCoreBase): @@ -131,6 +139,7 @@ def load_section(section: str): [ ("fibonacci", "fibonacci.asm", 500, {2: 2971215073}, basic_core_config), ("fibonacci_mem", "fibonacci_mem.asm", 400, {3: 55}, basic_core_config), + ("fibonacci_mem_tiny", "fibonacci_mem.asm", 250, {3: 55}, tiny_core_config), ("csr", "csr.asm", 200, {1: 1, 2: 4}, full_core_config), ("csr_mmode", "csr_mmode.asm", 1000, {1: 0, 2: 44, 3: 0, 4: 0, 5: 0, 6: 4, 15: 0}, full_core_config), ("exception", "exception.asm", 200, {1: 1, 2: 2}, basic_core_config), @@ -147,12 +156,11 @@ class TestCoreBasicAsm(TestCoreAsmSourceBase): expected_regvals: dict[int, int] configuration: CoreConfiguration - def run_and_check(self): - for _ in range(self.cycle_count): - yield Tick() + async def run_and_check(self, sim: TestbenchContext): + await self.tick(sim, self.cycle_count) for reg_id, val in self.expected_regvals.items(): - assert (yield from self.get_arch_reg_val(reg_id)) == val + assert self.get_arch_reg_val(sim, reg_id) == val def test_asm_source(self): self.gen_params = GenParams(self.configuration) @@ -165,7 +173,7 @@ def test_asm_source(self): self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src["text"], data_mem=bin_src["data"]) with self.run_simulation(self.m) as sim: - sim.add_process(self.run_and_check) + sim.add_testbench(self.run_and_check) # test interrupts with varying triggering frequency (parametrizable number of cycles between @@ -186,6 +194,15 @@ def test_asm_source(self): ("interrupt.asm", 600, {4: 89, 8: 843}, {2: 89, 7: 843, 31: 0xDE}, 30, 50, False), # interrupts are only inserted on branches; we always have some forward progression. 15 for trigger variation.
("interrupt.asm", 80, {4: 21, 8: 9349}, {2: 21, 7: 9349, 31: 0xDE}, 0, 15, False), + ( + "interrupt_vectored.asm", + 200, + {4: 21, 8: 9349, 15: 24476}, + {2: 21, 7: 9349, 14: 24476, 31: 0xDE, 16: 0x201, 17: 0x111}, + 0, + 15, + False, + ), ("wfi_int.asm", 80, {2: 10}, {2: 10, 3: 10}, 5, 15, True), ], ) @@ -207,40 +224,37 @@ def setup_method(self): self.gen_params = GenParams(self.configuration) random.seed(1500100900) - def clear_level_interrupt_procsess(self): - yield Passive() - while True: - while (yield self.m.core.csr_generic.csr_coreblocks_test.value) == 0: - yield Tick() + async def clear_level_interrupt_process(self, sim: ProcessContext): + async for *_, value in sim.tick().sample(self.m.core.csr_generic.csr_coreblocks_test.value): + if value == 0: + continue - if (yield self.m.core.csr_generic.csr_coreblocks_test.value) == 2: + if value == 2: assert False, "`fail` called" - yield self.m.core.csr_generic.csr_coreblocks_test.value.eq(0) - yield self.m.interrupt_level.eq(0) - yield Tick() + sim.set(self.m.core.csr_generic.csr_coreblocks_test.value, 0) + sim.set(self.m.interrupt_level, 0) - def run_with_interrupt_process(self): + async def run_with_interrupt_process(self, sim: TestbenchContext): main_cycles = 0 int_count = 0 handler_count = 0 # wait for interrupt enable - while (yield self.m.core.interrupt_controller.mstatus_mie.value) == 0: - yield Tick() + await sim.tick().until(self.m.core.interrupt_controller.mstatus_mie.value) - def do_interrupt(): + async def do_interrupt(): count = 0 trig = random.randint(1, 3) - mie = (yield self.m.core.interrupt_controller.mie.value) >> 16 + mie = sim.get(self.m.core.interrupt_controller.mie.value) >> 16 if mie != 0b11 or trig & 1 or self.edge_only: - yield self.m.interrupt_edge.eq(1) + sim.set(self.m.interrupt_edge, 1) count += 1 - if (mie != 0b11 or trig & 2) and (yield self.m.interrupt_level) == 0 and not self.edge_only: - yield self.m.interrupt_level.eq(1) + if (mie != 0b11 or trig & 2) and sim.get(self.m.interrupt_level) == 0 and not self.edge_only: + sim.set(self.m.interrupt_level, 1) count += 1 - yield Tick() - yield self.m.interrupt_edge.eq(0) + await sim.tick() + sim.set(self.m.interrupt_edge, 0) return count early_interrupt = False @@ -249,40 +263,35 @@ def do_interrupt(): # run main code for some semi-random amount of cycles c = random.randrange(self.lo, self.hi) main_cycles += c - yield from self.tick(c) + await self.tick(sim, c) # trigger an interrupt - int_count += yield from do_interrupt() + int_count += await do_interrupt() # wait for the interrupt to get registered - while (yield self.m.core.interrupt_controller.mstatus_mie.value) == 1: - yield Tick() + await sim.tick().until(self.m.core.interrupt_controller.mstatus_mie.value != 1) # trigger interrupt during execution of ISR handler (blocked-pending) with some chance early_interrupt = random.random() < 0.4 if early_interrupt: # wait until interrupts are cleared, so it won't be missed - while (yield self.m.core.interrupt_controller.mip.value) != 0: - yield Tick() - - assert (yield from self.get_arch_reg_val(30)) == int_count + await sim.tick().until(self.m.core.interrupt_controller.mip.value == 0) + assert self.get_arch_reg_val(sim, 30) == int_count - int_count += yield from do_interrupt() + int_count += await do_interrupt() else: - while (yield self.m.core.interrupt_controller.mip.value) != 0: - yield Tick() - assert (yield from self.get_arch_reg_val(30)) == int_count + await sim.tick().until(self.m.core.interrupt_controller.mip.value == 0) + assert 
self.get_arch_reg_val(sim, 30) == int_count handler_count += 1 # wait until ISR returns - while (yield self.m.core.interrupt_controller.mstatus_mie.value) == 0: - yield Tick() + await sim.tick().until(self.m.core.interrupt_controller.mstatus_mie.value != 0) - assert (yield from self.get_arch_reg_val(30)) == int_count - assert (yield from self.get_arch_reg_val(27)) == handler_count + assert self.get_arch_reg_val(sim, 30) == int_count + assert self.get_arch_reg_val(sim, 27) == handler_count for reg_id, val in self.expected_regvals.items(): - assert (yield from self.get_arch_reg_val(reg_id)) == val + assert self.get_arch_reg_val(sim, reg_id) == val def test_interrupted_prog(self): bin_src = self.prepare_source(self.source_file) @@ -290,14 +299,14 @@ def test_interrupted_prog(self): bin_src["data"][self.reg_init_mem_offset // 4 + reg_id] = val self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src["text"], data_mem=bin_src["data"]) with self.run_simulation(self.m) as sim: - sim.add_process(self.run_with_interrupt_process) - sim.add_process(self.clear_level_interrupt_procsess) + sim.add_testbench(self.run_with_interrupt_process) + sim.add_process(self.clear_level_interrupt_process) @parameterized_class( ("source_file", "cycle_count", "expected_regvals", "always_mmode"), [ - ("user_mode.asm", 1000, {4: 5}, False), + ("user_mode.asm", 1100, {4: 5}, False), ("wfi_no_mie.asm", 250, {8: 8}, True), # only using level enable ], ) @@ -314,48 +323,45 @@ def setup_method(self): self.gen_params = GenParams(self.configuration) random.seed(161453) - def run_with_interrupt_process(self): - cycles = 0 + async def run_with_interrupt_process(self, sim: TestbenchContext): + ticks = DependencyContext.get().get_dependency(TicksKey()) # wait for interrupt enable - while (yield self.m.core.interrupt_controller.mie.value) == 0 and cycles < self.cycle_count: - cycles += 1 - yield Tick() - yield from self.random_wait(5) + async def wait_or_timeout(cond: ValueLike, pred: Callable[[Any], bool]): + async for *_, value in sim.tick().sample(cond): + if pred(value) or sim.get(ticks) >= self.cycle_count: + break + + await wait_or_timeout(self.m.core.interrupt_controller.mie.value, lambda value: value != 0) + await self.random_wait(sim, 5) - while cycles < self.cycle_count: - yield self.m.interrupt_level.eq(1) - cycles += 1 - yield Tick() + while sim.get(ticks) < self.cycle_count: + sim.set(self.m.interrupt_level, 1) if self.always_mmode: # if test happens only in m_mode, just enable fixed interrupt + await sim.tick() continue # wait for the interrupt to get registered - while ( - yield self.m.core.csr_generic.m_mode.priv_mode.value - ) != PrivilegeLevel.MACHINE and cycles < self.cycle_count: - cycles += 1 - yield Tick() + await wait_or_timeout( + self.m.core.csr_generic.m_mode.priv_mode.value, lambda value: value == PrivilegeLevel.MACHINE + ) - yield self.m.interrupt_level.eq(0) - yield Tick() + sim.set(self.m.interrupt_level, 0) # wait until ISR returns - while ( - yield self.m.core.csr_generic.m_mode.priv_mode.value - ) == PrivilegeLevel.MACHINE and cycles < self.cycle_count: - cycles += 1 - yield Tick() + await wait_or_timeout( + self.m.core.csr_generic.m_mode.priv_mode.value, lambda value: value != PrivilegeLevel.MACHINE + ) - yield from self.random_wait(5) + await self.random_wait(sim, 5) for reg_id, val in self.expected_regvals.items(): - assert (yield from self.get_arch_reg_val(reg_id)) == val + assert self.get_arch_reg_val(sim, reg_id) == val def test_interrupted_prog(self): bin_src = 
self.prepare_source(self.source_file) self.m = CoreTestElaboratable(self.gen_params, instr_mem=bin_src["text"], data_mem=bin_src["data"]) with self.run_simulation(self.m) as sim: - sim.add_process(self.run_with_interrupt_process) + sim.add_testbench(self.run_with_interrupt_process) diff --git a/test/transactron/core/__init__.py b/test/transactron/core/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/core/test_transactions.py b/test/transactron/core/test_transactions.py deleted file mode 100644 index 46ef5f6d7..000000000 --- a/test/transactron/core/test_transactions.py +++ /dev/null @@ -1,452 +0,0 @@ -from abc import abstractmethod -from unittest.case import TestCase -import pytest -from amaranth import * -from amaranth.sim import * - -import random -import contextlib - -from collections import deque -from typing import Iterable, Callable -from parameterized import parameterized, parameterized_class - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout - -from transactron import * -from transactron.lib import Adapter, AdapterTrans -from transactron.utils import Scheduler - -from transactron.core import Priority -from transactron.core.schedulers import trivial_roundrobin_cc_scheduler, eager_deterministic_cc_scheduler -from transactron.core.manager import TransactionScheduler -from transactron.utils.dependencies import DependencyContext - - -class TestNames(TestCase): - def test_names(self): - mgr = TransactionManager() - mgr._MustUse__silence = True # type: ignore - - class T(Elaboratable): - def __init__(self): - self._MustUse__silence = True # type: ignore - Transaction(manager=mgr) - - T() - assert mgr.transactions[0].name == "T" - - t = Transaction(name="x", manager=mgr) - assert t.name == "x" - - t = Transaction(manager=mgr) - assert t.name == "t" - - m = Method(name="x") - assert m.name == "x" - - m = Method() - assert m.name == "m" - - -class TestScheduler(TestCaseWithSimulator): - def count_test(self, sched, cnt): - assert sched.count == cnt - assert len(sched.requests) == cnt - assert len(sched.grant) == cnt - assert len(sched.valid) == 1 - - def sim_step(self, sched, request, expected_grant): - yield sched.requests.eq(request) - yield Tick() - - if request == 0: - assert not (yield sched.valid) - else: - assert (yield sched.grant) == expected_grant - assert (yield sched.valid) - - def test_single(self): - sched = Scheduler(1) - self.count_test(sched, 1) - - def process(): - yield from self.sim_step(sched, 0, 0) - yield from self.sim_step(sched, 1, 1) - yield from self.sim_step(sched, 1, 1) - yield from self.sim_step(sched, 0, 0) - - with self.run_simulation(sched) as sim: - sim.add_process(process) - - def test_multi(self): - sched = Scheduler(4) - self.count_test(sched, 4) - - def process(): - yield from self.sim_step(sched, 0b0000, 0b0000) - yield from self.sim_step(sched, 0b1010, 0b0010) - yield from self.sim_step(sched, 0b1010, 0b1000) - yield from self.sim_step(sched, 0b1010, 0b0010) - yield from self.sim_step(sched, 0b1001, 0b1000) - yield from self.sim_step(sched, 0b1001, 0b0001) - - yield from self.sim_step(sched, 0b1111, 0b0010) - yield from self.sim_step(sched, 0b1111, 0b0100) - yield from self.sim_step(sched, 0b1111, 0b1000) - yield from self.sim_step(sched, 0b1111, 0b0001) - - yield from self.sim_step(sched, 0b0000, 0b0000) - yield from self.sim_step(sched, 0b0010, 0b0010) - yield from self.sim_step(sched, 0b0010, 0b0010) - - with self.run_simulation(sched) as sim: - sim.add_process(process) - - 
-class TransactionConflictTestCircuit(Elaboratable): - def __init__(self, scheduler): - self.scheduler = scheduler - - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler)) - adapter = Adapter(i=data_layout(32), o=data_layout(32)) - m.submodules.out = self.out = TestbenchIO(adapter) - m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface)) - m.submodules.in2 = self.in2 = TestbenchIO(AdapterTrans(adapter.iface)) - return tm - - -@parameterized_class( - ("name", "scheduler"), - [ - ("trivial_roundrobin", trivial_roundrobin_cc_scheduler), - ("eager_deterministic", eager_deterministic_cc_scheduler), - ], -) -class TestTransactionConflict(TestCaseWithSimulator): - scheduler: TransactionScheduler - - def setup_method(self): - random.seed(42) - - def make_process( - self, io: TestbenchIO, prob: float, src: Iterable[int], tgt: Callable[[int], None], chk: Callable[[int], None] - ): - def process(): - for i in src: - while random.random() >= prob: - yield Tick() - tgt(i) - r = yield from io.call(data=i) - chk(r["data"]) - - return process - - def make_in1_process(self, prob: float): - def tgt(x: int): - self.out1_expected.append(x) - - def chk(x: int): - assert x == self.in_expected.popleft() - - return self.make_process(self.m.in1, prob, self.in1_stream, tgt, chk) - - def make_in2_process(self, prob: float): - def tgt(x: int): - self.out2_expected.append(x) - - def chk(x: int): - assert x == self.in_expected.popleft() - - return self.make_process(self.m.in2, prob, self.in2_stream, tgt, chk) - - def make_out_process(self, prob: float): - def tgt(x: int): - self.in_expected.append(x) - - def chk(x: int): - if self.out1_expected and x == self.out1_expected[0]: - self.out1_expected.popleft() - elif self.out2_expected and x == self.out2_expected[0]: - self.out2_expected.popleft() - else: - assert False, "%d not found in any of the queues" % x - - return self.make_process(self.m.out, prob, self.out_stream, tgt, chk) - - @parameterized.expand( - [ - ("fullcontention", 1, 1, 1), - ("highcontention", 0.5, 0.5, 0.75), - ("lowcontention", 0.1, 0.1, 0.5), - ] - ) - def test_calls(self, name, prob1, prob2, probout): - self.in1_stream = range(0, 100) - self.in2_stream = range(100, 200) - self.out_stream = range(200, 400) - self.in_expected = deque() - self.out1_expected = deque() - self.out2_expected = deque() - self.m = TransactionConflictTestCircuit(self.__class__.scheduler) - - with self.run_simulation(self.m, add_transaction_module=False) as sim: - sim.add_process(self.make_in1_process(prob1)) - sim.add_process(self.make_in2_process(prob2)) - sim.add_process(self.make_out_process(probout)) - - assert not self.in_expected - assert not self.out1_expected - assert not self.out2_expected - - -class SchedulingTestCircuit(Elaboratable): - def __init__(self): - self.r1 = Signal() - self.r2 = Signal() - self.t1 = Signal() - self.t2 = Signal() - - @abstractmethod - def elaborate(self, platform) -> TModule: - raise NotImplementedError - - -class PriorityTestCircuit(SchedulingTestCircuit): - def __init__(self, priority: Priority, unsatisfiable=False): - super().__init__() - self.priority = priority - self.unsatisfiable = unsatisfiable - - def make_relations(self, t1: Transaction | Method, t2: Transaction | Method): - t1.add_conflict(t2, self.priority) - if self.unsatisfiable: - t2.add_conflict(t1, self.priority) - - -class TransactionPriorityTestCircuit(PriorityTestCircuit): - def elaborate(self, platform): - m = TModule() 
- - transaction1 = Transaction() - transaction2 = Transaction() - - with transaction1.body(m, request=self.r1): - m.d.comb += self.t1.eq(1) - - with transaction2.body(m, request=self.r2): - m.d.comb += self.t2.eq(1) - - self.make_relations(transaction1, transaction2) - - return m - - -class MethodPriorityTestCircuit(PriorityTestCircuit): - def elaborate(self, platform): - m = TModule() - - method1 = Method() - method2 = Method() - - @def_method(m, method1, ready=self.r1) - def _(): - m.d.comb += self.t1.eq(1) - - @def_method(m, method2, ready=self.r2) - def _(): - m.d.comb += self.t2.eq(1) - - with Transaction().body(m): - method1(m) - - with Transaction().body(m): - method2(m) - - self.make_relations(method1, method2) - - return m - - -@parameterized_class( - ("name", "circuit"), [("transaction", TransactionPriorityTestCircuit), ("method", MethodPriorityTestCircuit)] -) -class TestTransactionPriorities(TestCaseWithSimulator): - circuit: type[PriorityTestCircuit] - - def setup_method(self): - random.seed(42) - - @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) - def test_priorities(self, priority: Priority): - m = self.circuit(priority) - - def process(): - to_do = 5 * [(0, 1), (1, 0), (1, 1)] - random.shuffle(to_do) - for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Settle() - assert (yield m.t1) != (yield m.t2) - if r1 == 1 and r2 == 1: - if priority == Priority.LEFT: - assert (yield m.t1) - if priority == Priority.RIGHT: - assert (yield m.t2) - - with self.run_simulation(m) as sim: - sim.add_process(process) - - @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) - def test_unsatisfiable(self, priority: Priority): - m = self.circuit(priority, True) - - import graphlib - - if priority != Priority.UNDEFINED: - cm = pytest.raises(graphlib.CycleError) - else: - cm = contextlib.nullcontext() - - with cm: - with self.run_simulation(m): - pass - - -class NestedTransactionsTestCircuit(SchedulingTestCircuit): - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m) - - with tm.context(): - with Transaction().body(m, request=self.r1): - m.d.comb += self.t1.eq(1) - with Transaction().body(m, request=self.r2): - m.d.comb += self.t2.eq(1) - - return tm - - -class NestedMethodsTestCircuit(SchedulingTestCircuit): - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m) - - method1 = Method() - method2 = Method() - - @def_method(m, method1, ready=self.r1) - def _(): - m.d.comb += self.t1.eq(1) - - @def_method(m, method2, ready=self.r2) - def _(): - m.d.comb += self.t2.eq(1) - - with tm.context(): - with Transaction().body(m): - method1(m) - - with Transaction().body(m): - method2(m) - - return tm - - -@parameterized_class( - ("name", "circuit"), [("transaction", NestedTransactionsTestCircuit), ("method", NestedMethodsTestCircuit)] -) -class TestNested(TestCaseWithSimulator): - circuit: type[SchedulingTestCircuit] - - def setup_method(self): - random.seed(42) - - def test_scheduling(self): - m = self.circuit() - - def process(): - to_do = 5 * [(0, 1), (1, 0), (1, 1)] - random.shuffle(to_do) - for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Tick() - assert (yield m.t1) == r1 - assert (yield m.t2) == r1 * r2 - - with self.run_simulation(m) as sim: - sim.add_process(process) - - -class ScheduleBeforeTestCircuit(SchedulingTestCircuit): - def elaborate(self, platform): - m = TModule() - tm = TransactionModule(m) - - method = Method() - - @def_method(m, method) - 
def _(): - pass - - with tm.context(): - with (t1 := Transaction()).body(m, request=self.r1): - method(m) - m.d.comb += self.t1.eq(1) - - with (t2 := Transaction()).body(m, request=self.r2 & t1.grant): - method(m) - m.d.comb += self.t2.eq(1) - - t1.schedule_before(t2) - - return tm - - -class TestScheduleBefore(TestCaseWithSimulator): - def setup_method(self): - random.seed(42) - - def test_schedule_before(self): - m = ScheduleBeforeTestCircuit() - - def process(): - to_do = 5 * [(0, 1), (1, 0), (1, 1)] - random.shuffle(to_do) - for r1, r2 in to_do: - yield m.r1.eq(r1) - yield m.r2.eq(r2) - yield Tick() - assert (yield m.t1) == r1 - assert not (yield m.t2) - - with self.run_simulation(m) as sim: - sim.add_process(process) - - -class SingleCallerTestCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - method = Method(single_caller=True) - - with Transaction().body(m): - method(m) - - with Transaction().body(m): - method(m) - - return m - - -class TestSingleCaller(TestCaseWithSimulator): - def test_single_caller(self): - m = SingleCallerTestCircuit() - - with pytest.raises(RuntimeError): - with self.run_simulation(m): - pass diff --git a/test/transactron/lib/__init__.py b/test/transactron/lib/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/lib/test_fifo.py b/test/transactron/lib/test_fifo.py deleted file mode 100644 index 39de8929a..000000000 --- a/test/transactron/lib/test_fifo.py +++ /dev/null @@ -1,79 +0,0 @@ -from amaranth import * -from amaranth.sim import Settle, Tick - -from transactron.lib import AdapterTrans, BasicFifo - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout -from collections import deque -from parameterized import parameterized_class -import random - - -class BasicFifoTestCircuit(Elaboratable): - def __init__(self, depth): - self.depth = depth - - def elaborate(self, platform): - m = Module() - - m.submodules.fifo = self.fifo = BasicFifo(layout=data_layout(8), depth=self.depth) - - m.submodules.fifo_read = self.fifo_read = TestbenchIO(AdapterTrans(self.fifo.read)) - m.submodules.fifo_write = self.fifo_write = TestbenchIO(AdapterTrans(self.fifo.write)) - m.submodules.fifo_clear = self.fifo_clear = TestbenchIO(AdapterTrans(self.fifo.clear)) - - return m - - -@parameterized_class( - ("name", "depth"), - [ - ("notpower", 5), - ("power", 4), - ], -) -class TestBasicFifo(TestCaseWithSimulator): - depth: int - - def test_randomized(self): - fifoc = BasicFifoTestCircuit(depth=self.depth) - expq = deque() - - cycles = 256 - random.seed(42) - - self.done = False - - def source(): - for _ in range(cycles): - if random.randint(0, 1): - yield # random delay - - v = random.randint(0, (2**fifoc.fifo.width) - 1) - yield from fifoc.fifo_write.call(data=v) - expq.appendleft(v) - - if random.random() < 0.005: - yield from fifoc.fifo_clear.call() - yield Settle() - expq.clear() - - self.done = True - - def target(): - while not self.done or expq: - if random.randint(0, 1): - yield Tick() - - yield from fifoc.fifo_read.call_init() - yield Tick() - - v = yield from fifoc.fifo_read.call_result() - if v is not None: - assert v["data"] == expq.pop() - - yield from fifoc.fifo_read.disable() - - with self.run_simulation(fifoc) as sim: - sim.add_process(source) - sim.add_process(target) diff --git a/test/transactron/lib/test_transaction_lib.py b/test/transactron/lib/test_transaction_lib.py deleted file mode 100644 index 217897347..000000000 --- a/test/transactron/lib/test_transaction_lib.py +++ 
/dev/null @@ -1,804 +0,0 @@ -import pytest -from itertools import product -import random -from operator import and_ -from functools import reduce -from amaranth.sim import Settle, Tick -from typing import Optional, TypeAlias -from parameterized import parameterized -from collections import deque - -from amaranth import * -from transactron import * -from transactron.lib import * -from transactron.utils._typing import ModuleLike, MethodStruct, RecordDict -from transactron.utils import ModuleConnector -from transactron.testing import ( - SimpleTestCircuit, - TestCaseWithSimulator, - TestbenchIO, - data_layout, - def_method_mock, -) - - -class RevConnect(Elaboratable): - def __init__(self, layout: MethodLayout): - self.connect = Connect(rev_layout=layout) - self.read = self.connect.write - self.write = self.connect.read - - def elaborate(self, platform): - return self.connect - - -FIFO_Like: TypeAlias = FIFO | Forwarder | Connect | RevConnect | Pipe - - -class TestFifoBase(TestCaseWithSimulator): - def do_test_fifo( - self, fifo_class: type[FIFO_Like], writer_rand: int = 0, reader_rand: int = 0, fifo_kwargs: dict = {} - ): - iosize = 8 - - m = SimpleTestCircuit(fifo_class(data_layout(iosize), **fifo_kwargs)) - - random.seed(1337) - - def writer(): - for i in range(2**iosize): - yield from m.write.call(data=i) - yield from self.random_wait(writer_rand) - - def reader(): - for i in range(2**iosize): - assert (yield from m.read.call()) == {"data": i} - yield from self.random_wait(reader_rand) - - with self.run_simulation(m) as sim: - sim.add_process(reader) - sim.add_process(writer) - - -class TestFIFO(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(FIFO, writer_rand=writer_rand, reader_rand=reader_rand, fifo_kwargs=dict(depth=4)) - - -class TestConnect(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(Connect, writer_rand=writer_rand, reader_rand=reader_rand) - - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_rev_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(RevConnect, writer_rand=writer_rand, reader_rand=reader_rand) - - -class TestForwarder(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(Forwarder, writer_rand=writer_rand, reader_rand=reader_rand) - - def test_forwarding(self): - iosize = 8 - - m = SimpleTestCircuit(Forwarder(data_layout(iosize))) - - def forward_check(x): - yield from m.read.call_init() - yield from m.write.call_init(data=x) - yield Settle() - assert (yield from m.read.call_result()) == {"data": x} - assert (yield from m.write.call_result()) is not None - yield Tick() - - def process(): - # test forwarding behavior - for x in range(4): - yield from forward_check(x) - - # load the overflow buffer - yield from m.read.disable() - yield from m.write.call_init(data=42) - yield Settle() - assert (yield from m.write.call_result()) is not None - yield Tick() - - # writes are not possible now - yield from m.write.call_init(data=84) - yield Settle() - assert (yield from m.write.call_result()) is None - yield Tick() - - # read from the overflow buffer, writes still blocked - yield from m.read.enable() - yield from m.write.call_init(data=111) - yield Settle() - assert (yield from m.read.call_result()) == {"data": 42} - assert (yield from m.write.call_result()) is None - yield 
Tick() - - # forwarding now works again - for x in range(4): - yield from forward_check(x) - - with self.run_simulation(m) as sim: - sim.add_process(process) - - -class TestPipe(TestFifoBase): - @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) - def test_fifo(self, writer_rand, reader_rand): - self.do_test_fifo(Pipe, writer_rand=writer_rand, reader_rand=reader_rand) - - def test_pipelining(self): - self.do_test_fifo(Pipe, writer_rand=0, reader_rand=0) - - -class TestMemoryBank(TestCaseWithSimulator): - test_conf = [(9, 3, 3, 3, 14), (16, 1, 1, 3, 15), (16, 1, 1, 1, 16), (12, 3, 1, 1, 17), (9, 0, 0, 0, 18)] - - @pytest.mark.parametrize("max_addr, writer_rand, reader_req_rand, reader_resp_rand, seed", test_conf) - @pytest.mark.parametrize("transparent", [False, True]) - @pytest.mark.parametrize("read_ports", [1, 2]) - @pytest.mark.parametrize("write_ports", [1, 2]) - def test_mem( - self, - max_addr: int, - writer_rand: int, - reader_req_rand: int, - reader_resp_rand: int, - seed: int, - transparent: bool, - read_ports: int, - write_ports: int, - ): - test_count = 200 - - data_width = 6 - m = SimpleTestCircuit( - MemoryBank( - data_layout=[("data", data_width)], - elem_count=max_addr, - transparent=transparent, - read_ports=read_ports, - write_ports=write_ports, - ) - ) - - data: list[int] = [0 for _ in range(max_addr)] - read_req_queues = [deque() for _ in range(read_ports)] - - random.seed(seed) - - def writer(i): - def process(): - for cycle in range(test_count): - d = random.randrange(2**data_width) - a = random.randrange(max_addr) - yield from m.writes[i].call(data=d, addr=a) - for _ in range(i + 2 if not transparent else i): - yield Settle() - data[a] = d - yield from self.random_wait(writer_rand) - - return process - - def reader_req(i): - def process(): - for cycle in range(test_count): - a = random.randrange(max_addr) - yield from m.read_reqs[i].call(addr=a) - for _ in range(1 if not transparent else write_ports + 2): - yield Settle() - d = data[a] - read_req_queues[i].append(d) - yield from self.random_wait(reader_req_rand) - - return process - - def reader_resp(i): - def process(): - for cycle in range(test_count): - for _ in range(write_ports + 3): - yield Settle() - while not read_req_queues[i]: - yield from self.random_wait(reader_resp_rand or 1, min_cycle_cnt=1) - for _ in range(write_ports + 3): - yield Settle() - d = read_req_queues[i].popleft() - assert (yield from m.read_resps[i].call()) == {"data": d} - yield from self.random_wait(reader_resp_rand) - - return process - - pipeline_test = writer_rand == 0 and reader_req_rand == 0 and reader_resp_rand == 0 - max_cycles = test_count + 2 if pipeline_test else 100000 - - with self.run_simulation(m, max_cycles=max_cycles) as sim: - for i in range(read_ports): - sim.add_process(reader_req(i)) - sim.add_process(reader_resp(i)) - for i in range(write_ports): - sim.add_process(writer(i)) - - -class TestAsyncMemoryBank(TestCaseWithSimulator): - @pytest.mark.parametrize( - "max_addr, writer_rand, reader_rand, seed", [(9, 3, 3, 14), (16, 1, 1, 15), (16, 1, 1, 16), (12, 3, 1, 17)] - ) - @pytest.mark.parametrize("read_ports", [1, 2]) - @pytest.mark.parametrize("write_ports", [1, 2]) - def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, read_ports: int, write_ports: int): - test_count = 200 - - data_width = 6 - m = SimpleTestCircuit( - AsyncMemoryBank( - data_layout=[("data", data_width)], elem_count=max_addr, read_ports=read_ports, write_ports=write_ports - ) - ) - - data: list[int] = list(0 for i in 
range(max_addr)) - - random.seed(seed) - - def writer(i): - def process(): - for cycle in range(test_count): - d = random.randrange(2**data_width) - a = random.randrange(max_addr) - yield from m.writes[i].call(data=d, addr=a) - for _ in range(i + 2): - yield Settle() - data[a] = d - yield from self.random_wait(writer_rand, min_cycle_cnt=1) - - return process - - def reader(i): - def process(): - for cycle in range(test_count): - a = random.randrange(max_addr) - d = yield from m.reads[i].call(addr=a) - for _ in range(1): - yield Settle() - expected_d = data[a] - assert d["data"] == expected_d - yield from self.random_wait(reader_rand, min_cycle_cnt=1) - - return process - - with self.run_simulation(m) as sim: - for i in range(read_ports): - sim.add_process(reader(i)) - for i in range(write_ports): - sim.add_process(writer(i)) - - -class ManyToOneConnectTransTestCircuit(Elaboratable): - def __init__(self, count: int, lay: MethodLayout): - self.count = count - self.lay = lay - self.inputs = [] - - def elaborate(self, platform): - m = TModule() - - get_results = [] - for i in range(self.count): - input = TestbenchIO(Adapter(o=self.lay)) - get_results.append(input.adapter.iface) - m.submodules[f"input_{i}"] = input - self.inputs.append(input) - - # Create ManyToOneConnectTrans, which will serialize results from different inputs - output = TestbenchIO(Adapter(i=self.lay)) - m.submodules.output = output - self.output = output - m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface) - - return m - - -class TestManyToOneConnectTrans(TestCaseWithSimulator): - def initialize(self): - f1_size = 14 - f2_size = 3 - self.lay = [("field1", f1_size), ("field2", f2_size)] - - self.m = ManyToOneConnectTransTestCircuit(self.count, self.lay) - random.seed(14) - - self.inputs = [] - # Create list with info if we processed all data from inputs - self.producer_end = [False for i in range(self.count)] - self.expected_output = {} - self.max_wait = 4 - - # Prepare random results for inputs - for i in range(self.count): - data = [] - input_size = random.randint(20, 30) - for j in range(input_size): - t = ( - random.randint(0, 2**f1_size), - random.randint(0, 2**f2_size), - ) - data.append(t) - if t in self.expected_output: - self.expected_output[t] += 1 - else: - self.expected_output[t] = 1 - self.inputs.append(data) - - def generate_producer(self, i: int): - """ - This is an helper function, which generates a producer process, - which will simulate an FU. Producer will insert in random intervals new - results to its output FIFO. This records will be next serialized by FUArbiter. 
- """ - - def producer(): - inputs = self.inputs[i] - for field1, field2 in inputs: - io: TestbenchIO = self.m.inputs[i] - yield from io.call_init(field1=field1, field2=field2) - yield from self.random_wait(self.max_wait) - self.producer_end[i] = True - - return producer - - def consumer(self): - while reduce(and_, self.producer_end, True): - result = yield from self.m.output.call_do() - - assert result is not None - - t = (result["field1"], result["field2"]) - assert t in self.expected_output - if self.expected_output[t] == 1: - del self.expected_output[t] - else: - self.expected_output[t] -= 1 - yield from self.random_wait(self.max_wait) - - def test_one_out(self): - self.count = 1 - self.initialize() - with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) - for i in range(self.count): - sim.add_process(self.generate_producer(i)) - - def test_many_out(self): - self.count = 4 - self.initialize() - with self.run_simulation(self.m) as sim: - sim.add_process(self.consumer) - for i in range(self.count): - sim.add_process(self.generate_producer(i)) - - -class MethodMapTestCircuit(Elaboratable): - def __init__(self, iosize: int, use_methods: bool, use_dicts: bool): - self.iosize = iosize - self.use_methods = use_methods - self.use_dicts = use_dicts - - def elaborate(self, platform): - m = TModule() - - layout = data_layout(self.iosize) - - def itransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct: - s = Signal.like(v) - m.d.comb += s.data.eq(v.data + 1) - return s - - def otransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct: - s = Signal.like(v) - m.d.comb += s.data.eq(v.data - 1) - return s - - def itransform_dict(_, v: MethodStruct) -> RecordDict: - return {"data": v.data + 1} - - def otransform_dict(_, v: MethodStruct) -> RecordDict: - return {"data": v.data - 1} - - if self.use_dicts: - itransform = itransform_dict - otransform = otransform_dict - else: - itransform = itransform_rec - otransform = otransform_rec - - m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout)) - - if self.use_methods: - imeth = Method(i=layout, o=layout) - ometh = Method(i=layout, o=layout) - - @def_method(m, imeth) - def _(arg: MethodStruct): - return itransform(m, arg) - - @def_method(m, ometh) - def _(arg: MethodStruct): - return otransform(m, arg) - - trans = MethodMap(self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh)) - else: - trans = MethodMap( - self.target.adapter.iface, - i_transform=(layout, itransform), - o_transform=(layout, otransform), - ) - - m.submodules.source = self.source = TestbenchIO(AdapterTrans(trans.use(m))) - - return m - - -class TestMethodTransformer(TestCaseWithSimulator): - m: MethodMapTestCircuit - - def source(self): - for i in range(2**self.m.iosize): - v = yield from self.m.source.call(data=i) - i1 = (i + 1) & ((1 << self.m.iosize) - 1) - assert v["data"] == (((i1 << 1) | (i1 >> (self.m.iosize - 1))) - 1) & ((1 << self.m.iosize) - 1) - - @def_method_mock(lambda self: self.m.target) - def target(self, data): - return {"data": (data << 1) | (data >> (self.m.iosize - 1))} - - def test_method_transformer(self): - self.m = MethodMapTestCircuit(4, False, False) - with self.run_simulation(self.m) as sim: - sim.add_process(self.source) - sim.add_process(self.target) - - def test_method_transformer_dicts(self): - self.m = MethodMapTestCircuit(4, False, True) - with self.run_simulation(self.m) as sim: - sim.add_process(self.source) - - def test_method_transformer_with_methods(self): - self.m = 
MethodMapTestCircuit(4, True, True) - with self.run_simulation(self.m) as sim: - sim.add_process(self.source) - - -class TestMethodFilter(TestCaseWithSimulator): - def initialize(self): - self.iosize = 4 - self.layout = data_layout(self.iosize) - self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout)) - self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) - - def source(self): - for i in range(2**self.iosize): - v = yield from self.tc.method.call(data=i) - if i & 1: - assert v["data"] == (i + 1) & ((1 << self.iosize) - 1) - else: - assert v["data"] == 0 - - @def_method_mock(lambda self: self.target, sched_prio=2) - def target_mock(self, data): - return {"data": data + 1} - - @def_method_mock(lambda self: self.cmeth, sched_prio=1) - def cmeth_mock(self, data): - return {"data": data % 2} - - @parameterized.expand([(True,), (False,)]) - def test_method_filter_with_methods(self, use_condition): - self.initialize() - self.tc = SimpleTestCircuit( - MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface, use_condition=use_condition) - ) - m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) - with self.run_simulation(m) as sim: - sim.add_process(self.source) - - @parameterized.expand([(True,), (False,)]) - def test_method_filter(self, use_condition): - self.initialize() - - def condition(_, v): - return v.data[0] - - self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition, use_condition=use_condition)) - m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) - with self.run_simulation(m) as sim: - sim.add_process(self.source) - - -class MethodProductTestCircuit(Elaboratable): - def __init__(self, iosize: int, targets: int, add_combiner: bool): - self.iosize = iosize - self.targets = targets - self.add_combiner = add_combiner - self.target: list[TestbenchIO] = [] - - def elaborate(self, platform): - m = TModule() - - layout = data_layout(self.iosize) - - methods = [] - - for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) - methods.append(tgt.adapter.iface) - self.target.append(tgt) - m.submodules += tgt - - combiner = None - if self.add_combiner: - combiner = (layout, lambda _, vs: {"data": sum(x.data for x in vs)}) - - product = MethodProduct(methods, combiner) - - m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) - - return m - - -class TestMethodProduct(TestCaseWithSimulator): - @parameterized.expand([(1, False), (2, False), (5, True)]) - def test_method_product(self, targets: int, add_combiner: bool): - random.seed(14) - - iosize = 8 - m = MethodProductTestCircuit(iosize, targets, add_combiner) - - method_en = [False] * targets - - def target_process(k: int): - @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def process(data): - return {"data": data + k} - - return process - - def method_process(): - # if any of the target methods is not enabled, call does not succeed - for i in range(2**targets - 1): - for k in range(targets): - method_en[k] = bool(i & (1 << k)) - - yield Tick() - assert (yield from m.method.call_try(data=0)) is None - - # otherwise, the call succeeds - for k in range(targets): - method_en[k] = True - yield Tick() - - data = random.randint(0, (1 << iosize) - 1) - val = (yield from m.method.call(data=data))["data"] - if add_combiner: - assert val == (targets * data + (targets - 1) * targets // 2) & ((1 << iosize) - 1) - else: - assert val == data - - with self.run_simulation(m) as sim: - 
sim.add_process(method_process) - for k in range(targets): - sim.add_process(target_process(k)) - - -class TestSerializer(TestCaseWithSimulator): - def setup_method(self): - self.test_count = 100 - - self.port_count = 2 - self.data_width = 5 - - self.requestor_rand = 4 - - layout = [("field", self.data_width)] - - self.req_method = TestbenchIO(Adapter(i=layout)) - self.resp_method = TestbenchIO(Adapter(o=layout)) - - self.test_circuit = SimpleTestCircuit( - Serializer( - port_count=self.port_count, - serialized_req_method=self.req_method.adapter.iface, - serialized_resp_method=self.resp_method.adapter.iface, - ) - ) - self.m = ModuleConnector( - test_circuit=self.test_circuit, req_method=self.req_method, resp_method=self.resp_method - ) - - random.seed(14) - - self.serialized_data = deque() - self.port_data = [deque() for _ in range(self.port_count)] - - self.got_request = False - - @def_method_mock(lambda self: self.req_method, enable=lambda self: not self.got_request) - def serial_req_mock(self, field): - self.serialized_data.append(field) - self.got_request = True - - @def_method_mock(lambda self: self.resp_method, enable=lambda self: self.got_request) - def serial_resp_mock(self): - self.got_request = False - return {"field": self.serialized_data[-1]} - - def requestor(self, i: int): - def f(): - for _ in range(self.test_count): - d = random.randrange(2**self.data_width) - yield from self.test_circuit.serialize_in[i].call(field=d) - self.port_data[i].append(d) - yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) - - return f - - def responder(self, i: int): - def f(): - for _ in range(self.test_count): - data_out = yield from self.test_circuit.serialize_out[i].call() - assert self.port_data[i].popleft() == data_out["field"] - yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) - - return f - - def test_serial(self): - with self.run_simulation(self.m) as sim: - for i in range(self.port_count): - sim.add_process(self.requestor(i)) - sim.add_process(self.responder(i)) - - -class TestMethodTryProduct(TestCaseWithSimulator): - @parameterized.expand([(1, False), (2, False), (5, True)]) - def test_method_try_product(self, targets: int, add_combiner: bool): - random.seed(14) - - iosize = 8 - m = MethodTryProductTestCircuit(iosize, targets, add_combiner) - - method_en = [False] * targets - - def target_process(k: int): - @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) - def process(data): - return {"data": data + k} - - return process - - def method_process(): - for i in range(2**targets): - for k in range(targets): - method_en[k] = bool(i & (1 << k)) - - active_targets = sum(method_en) - - yield Tick() - - data = random.randint(0, (1 << iosize) - 1) - val = yield from m.method.call(data=data) - if add_combiner: - adds = sum(k * method_en[k] for k in range(targets)) - assert val == {"data": (active_targets * data + adds) & ((1 << iosize) - 1)} - else: - assert val == {} - - with self.run_simulation(m) as sim: - sim.add_process(method_process) - for k in range(targets): - sim.add_process(target_process(k)) - - -class MethodTryProductTestCircuit(Elaboratable): - def __init__(self, iosize: int, targets: int, add_combiner: bool): - self.iosize = iosize - self.targets = targets - self.add_combiner = add_combiner - self.target: list[TestbenchIO] = [] - - def elaborate(self, platform): - m = TModule() - - layout = data_layout(self.iosize) - - methods = [] - - for k in range(self.targets): - tgt = TestbenchIO(Adapter(i=layout, o=layout)) - 
methods.append(tgt.adapter.iface) - self.target.append(tgt) - m.submodules += tgt - - combiner = None - if self.add_combiner: - combiner = (layout, lambda _, vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)}) - - product = MethodTryProduct(methods, combiner) - - m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) - - return m - - -class ConditionTestCircuit(Elaboratable): - def __init__(self, target: Method, *, nonblocking: bool, priority: bool, catchall: bool): - self.target = target - self.source = Method(i=[("cond1", 1), ("cond2", 1), ("cond3", 1)], single_caller=True) - self.nonblocking = nonblocking - self.priority = priority - self.catchall = catchall - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.source) - def _(cond1, cond2, cond3): - with condition(m, nonblocking=self.nonblocking, priority=self.priority) as branch: - with branch(cond1): - self.target(m, cond=1) - with branch(cond2): - self.target(m, cond=2) - with branch(cond3): - self.target(m, cond=3) - if self.catchall: - with branch(): - self.target(m, cond=0) - - return m - - -class TestCondition(TestCaseWithSimulator): - @pytest.mark.parametrize("nonblocking", [False, True]) - @pytest.mark.parametrize("priority", [False, True]) - @pytest.mark.parametrize("catchall", [False, True]) - def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): - target = TestbenchIO(Adapter(i=[("cond", 2)])) - - circ = SimpleTestCircuit( - ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall) - ) - m = ModuleConnector(test_circuit=circ, target=target) - - selection: Optional[int] - - @def_method_mock(lambda: target) - def target_process(cond): - nonlocal selection - selection = cond - - def process(): - nonlocal selection - for c1, c2, c3 in product([0, 1], [0, 1], [0, 1]): - selection = None - res = yield from circ.source.call_try(cond1=c1, cond2=c2, cond3=c3) - - if catchall or nonblocking: - assert res is not None - - if res is None: - assert selection is None - assert not catchall or nonblocking - assert (c1, c2, c3) == (0, 0, 0) - elif selection is None: - assert nonblocking - assert (c1, c2, c3) == (0, 0, 0) - elif priority: - assert selection == c1 + 2 * c2 * (1 - c1) + 3 * c3 * (1 - c2) * (1 - c1) - else: - assert selection in [c1, 2 * c2, 3 * c3] - - with self.run_simulation(m) as sim: - sim.add_process(process) diff --git a/test/transactron/test_adapter.py b/test/transactron/test_adapter.py deleted file mode 100644 index a5fa73264..000000000 --- a/test/transactron/test_adapter.py +++ /dev/null @@ -1,62 +0,0 @@ -from amaranth import * - -from transactron import Method, def_method, TModule - - -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, ModuleConnector - - -class Echo(Elaboratable): - def __init__(self): - self.data_bits = 8 - - self.layout_in = data_layout(self.data_bits) - self.layout_out = data_layout(self.data_bits) - - self.action = Method(i=self.layout_in, o=self.layout_out) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.action, ready=C(1)) - def _(arg): - return arg - - return m - - -class Consumer(Elaboratable): - def __init__(self): - self.data_bits = 8 - - self.layout_in = data_layout(self.data_bits) - self.layout_out = [] - - self.action = Method(i=self.layout_in, o=self.layout_out) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.action, ready=C(1)) - def _(arg): - return None - - return m - - -class 
TestAdapterTrans(TestCaseWithSimulator): - def proc(self): - for _ in range(3): - # this would previously timeout if the output layout was empty (as is in this case) - yield from self.consumer.action.call(data=0) - for expected in [4, 1, 0]: - obtained = (yield from self.echo.action.call(data=expected))["data"] - assert expected == obtained - - def test_single(self): - self.echo = SimpleTestCircuit(Echo()) - self.consumer = SimpleTestCircuit(Consumer()) - self.m = ModuleConnector(echo=self.echo, consumer=self.consumer) - - with self.run_simulation(self.m, max_cycles=100) as sim: - sim.add_process(self.proc) diff --git a/test/transactron/test_assign.py b/test/transactron/test_assign.py deleted file mode 100644 index 7398570fa..000000000 --- a/test/transactron/test_assign.py +++ /dev/null @@ -1,160 +0,0 @@ -import pytest -from typing import Callable -from amaranth import * -from amaranth.lib import data -from amaranth.lib.enum import Enum -from amaranth.hdl._ast import ArrayProxy, SwitchValue, Slice - -from transactron.utils._typing import MethodLayout -from transactron.utils import AssignType, assign -from transactron.utils.assign import AssignArg, AssignFields - -from unittest import TestCase -from parameterized import parameterized_class, parameterized - - -class ExampleEnum(Enum, shape=1): - ZERO = 0 - ONE = 1 - - -def with_reversed(pairs: list[tuple[str, str]]): - return pairs + [(b, a) for (a, b) in pairs] - - -layout_a = [("a", 1)] -layout_ab = [("a", 1), ("b", 2)] -layout_ac = [("a", 1), ("c", 3)] -layout_a_alt = [("a", 2)] -layout_a_enum = [("a", ExampleEnum)] - -# Defines functions build, wrap, extr used in TestAssign -params_funs = { - "normal": (lambda mk, lay: mk(lay), lambda x: x, lambda r: r), - "rec": (lambda mk, lay: mk([("x", lay)]), lambda x: {"x": x}, lambda r: r.x), - "dict": (lambda mk, lay: {"x": mk(lay)}, lambda x: {"x": x}, lambda r: r["x"]), - "list": (lambda mk, lay: [mk(lay)], lambda x: {0: x}, lambda r: r[0]), - "union": ( - lambda mk, lay: Signal(data.UnionLayout({"x": reclayout2datalayout(lay)})), - lambda x: {"x": x}, - lambda r: r.x, - ), - "array": (lambda mk, lay: Signal(data.ArrayLayout(reclayout2datalayout(lay), 1)), lambda x: {0: x}, lambda r: r[0]), -} - - -params_pairs = [(k, k) for k in params_funs if k != "union"] + with_reversed( - [("rec", "dict"), ("list", "array"), ("union", "dict")] -) - - -def mkproxy(layout): - arr = Array([Signal(reclayout2datalayout(layout)) for _ in range(4)]) - sig = Signal(2) - return arr[sig] - - -def reclayout2datalayout(layout): - if not isinstance(layout, list): - return layout - return data.StructLayout({k: reclayout2datalayout(lay) for k, lay in layout}) - - -def mkstruct(layout): - return Signal(reclayout2datalayout(layout)) - - -params_mk = [ - ("proxy", mkproxy), - ("struct", mkstruct), -] - - -@parameterized_class( - ["name", "buildl", "wrapl", "extrl", "buildr", "wrapr", "extrr", "mk"], - [ - (f"{nl}_{nr}_{c}", *map(staticmethod, params_funs[nl] + params_funs[nr] + (m,))) - for nl, nr in params_pairs - for c, m in params_mk - ], -) -class TestAssign(TestCase): - # constructs `assign` arguments (views, proxies, dicts) which have an "inner" and "outer" part - # parameterized with a constructor and a layout of the inner part - buildl: Callable[[Callable[[MethodLayout], AssignArg], MethodLayout], AssignArg] - buildr: Callable[[Callable[[MethodLayout], AssignArg], MethodLayout], AssignArg] - # constructs field specifications for `assign`, takes field specifications for the inner part - wrapl: 
Callable[[AssignFields], AssignFields] - wrapr: Callable[[AssignFields], AssignFields] - # extracts the inner part of the structure - extrl: Callable[[AssignArg], ArrayProxy] - extrr: Callable[[AssignArg], ArrayProxy] - # constructor, takes a layout - mk: Callable[[MethodLayout], AssignArg] - - def test_wraps_eq(self): - assert self.wrapl({}) == self.wrapr({}) - - def test_rhs_exception(self): - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_ab), fields=AssignType.RHS)) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_ac), fields=AssignType.RHS)) - - def test_all_exception(self): - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_ab), fields=AssignType.ALL)) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_a), fields=AssignType.ALL)) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_ac), fields=AssignType.ALL)) - - def test_missing_exception(self): - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_ab), fields=self.wrapl({"b"}))) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_ab), self.buildr(self.mk, layout_a), fields=self.wrapl({"b"}))) - with pytest.raises(KeyError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_a), fields=self.wrapl({"b"}))) - - def test_wrong_bits(self): - with pytest.raises(ValueError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_a_alt))) - if self.mk != mkproxy: # Arrays are troublesome and defeat some checks - with pytest.raises(ValueError): - list(assign(self.buildl(self.mk, layout_a), self.buildr(self.mk, layout_a_enum))) - - @parameterized.expand( - [ - ("lhs", layout_a, layout_ab, AssignType.LHS), - ("rhs", layout_ab, layout_a, AssignType.RHS), - ("all", layout_a, layout_a, AssignType.ALL), - ("common", layout_ab, layout_ac, AssignType.COMMON), - ("set", layout_ab, layout_ab, {"a"}), - ("list", layout_ab, layout_ab, ["a", "a"]), - ] - ) - def test_assign_a(self, name, layout1: MethodLayout, layout2: MethodLayout, atype: AssignType): - lhs = self.buildl(self.mk, layout1) - rhs = self.buildr(self.mk, layout2) - alist = list(assign(lhs, rhs, fields=self.wrapl(atype))) - assert len(alist) == 1 - self.assertIs_AP(alist[0].lhs, self.extrl(lhs).a) - self.assertIs_AP(alist[0].rhs, self.extrr(rhs).a) - - def assertIs_AP(self, expr1, expr2): # noqa: N802 - expr1 = Value.cast(expr1) - expr2 = Value.cast(expr2) - if isinstance(expr1, SwitchValue) and isinstance(expr2, SwitchValue): - # new proxies are created on each index, structural equality is needed - self.assertIs(expr1.test, expr2.test) - assert len(expr1.cases) == len(expr2.cases) - for (px, x), (py, y) in zip(expr1.cases, expr2.cases): - self.assertEqual(px, py) - self.assertIs_AP(x, y) - elif isinstance(expr1, Slice) and isinstance(expr2, Slice): - self.assertIs_AP(expr1.value, expr2.value) - assert expr1.start == expr2.start - assert expr1.stop == expr2.stop - else: - self.assertIs(expr1, expr2) diff --git a/test/transactron/test_branches.py b/test/transactron/test_branches.py deleted file mode 100644 index bfb1d5842..000000000 --- a/test/transactron/test_branches.py +++ /dev/null @@ -1,99 +0,0 @@ -from amaranth import * -from itertools import product -from transactron.core import ( - TModule, - 
Method, - Transaction, - TransactionManager, - TransactionModule, - def_method, -) -from transactron.core.tmodule import CtrlPath -from transactron.core.manager import MethodMap -from unittest import TestCase -from transactron.testing import TestCaseWithSimulator -from transactron.utils.dependencies import DependencyContext - - -class TestExclusivePath(TestCase): - def test_exclusive_path(self): - m = TModule() - m._MustUse__silence = True # type: ignore - - with m.If(0): - cp0 = m.ctrl_path - with m.Switch(3): - with m.Case(0): - cp0a0 = m.ctrl_path - with m.Case(1): - cp0a1 = m.ctrl_path - with m.Default(): - cp0a2 = m.ctrl_path - with m.If(1): - cp0b0 = m.ctrl_path - with m.Else(): - cp0b1 = m.ctrl_path - with m.Elif(1): - cp1 = m.ctrl_path - with m.FSM(): - with m.State("start"): - cp10 = m.ctrl_path - with m.State("next"): - cp11 = m.ctrl_path - with m.Else(): - cp2 = m.ctrl_path - - def mutually_exclusive(*cps: CtrlPath): - return all(cpa.exclusive_with(cpb) for i, cpa in enumerate(cps) for cpb in cps[i + 1 :]) - - def pairwise_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): - return all(cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) - - def pairwise_not_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): - return all(not cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) - - assert mutually_exclusive(cp0, cp1, cp2) - assert mutually_exclusive(cp0a0, cp0a1, cp0a2) - assert mutually_exclusive(cp0b0, cp0b1) - assert mutually_exclusive(cp10, cp11) - assert pairwise_exclusive([cp0, cp0a0, cp0a1, cp0a2, cp0b0, cp0b1], [cp1, cp10, cp11]) - assert pairwise_not_exclusive([cp0, cp0a0, cp0a1, cp0a2], [cp0, cp0b0, cp0b1]) - - -class ExclusiveConflictRemovalCircuit(Elaboratable): - def __init__(self): - self.sel = Signal() - - def elaborate(self, platform): - m = TModule() - - called_method = Method(i=[], o=[]) - - @def_method(m, called_method) - def _(): - pass - - with m.If(self.sel): - with Transaction().body(m): - called_method(m) - with m.Else(): - with Transaction().body(m): - called_method(m) - - return m - - -class TestExclusiveConflictRemoval(TestCaseWithSimulator): - def test_conflict_removal(self): - circ = ExclusiveConflictRemovalCircuit() - - tm = TransactionManager() - dut = TransactionModule(circ, DependencyContext.get(), tm) - - with self.run_simulation(dut, add_transaction_module=False): - pass - - cgr, _ = tm._conflict_graph(MethodMap(tm.transactions)) - - for s in cgr.values(): - assert not s diff --git a/test/transactron/test_connectors.py b/test/transactron/test_connectors.py deleted file mode 100644 index ac15a9f9d..000000000 --- a/test/transactron/test_connectors.py +++ /dev/null @@ -1,46 +0,0 @@ -import random -from parameterized import parameterized_class - -from amaranth.sim import Settle, Tick - -from transactron.lib import StableSelectingNetwork -from transactron.testing import TestCaseWithSimulator - - -@parameterized_class( - ("n"), - [(2,), (3,), (7,), (8,)], -) -class TestStableSelectingNetwork(TestCaseWithSimulator): - n: int - - def test(self): - m = StableSelectingNetwork(self.n, [("data", 8)]) - - random.seed(42) - - def process(): - for _ in range(100): - inputs = [random.randrange(2**8) for _ in range(self.n)] - valids = [random.randrange(2) for _ in range(self.n)] - total = sum(valids) - - expected_output_prefix = [] - for i in range(self.n): - yield m.valids[i].eq(valids[i]) - yield m.inputs[i].data.eq(inputs[i]) - - if valids[i]: - expected_output_prefix.append(inputs[i]) - - yield Settle() - - for i in range(total): - 
out = yield m.outputs[i].data - assert out == expected_output_prefix[i] - - assert (yield m.output_cnt) == total - yield Tick() - - with self.run_simulation(m) as sim: - sim.add_process(process) diff --git a/test/transactron/test_methods.py b/test/transactron/test_methods.py deleted file mode 100644 index e03ae5f17..000000000 --- a/test/transactron/test_methods.py +++ /dev/null @@ -1,812 +0,0 @@ -from collections.abc import Callable, Sequence -import pytest -import random -from amaranth import * -from amaranth.sim import * - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout - -from transactron import * -from transactron.testing.infrastructure import SimpleTestCircuit -from transactron.utils import MethodStruct -from transactron.lib import * - -from parameterized import parameterized - -from unittest import TestCase - -from transactron.utils.assign import AssignArg - - -class TestDefMethod(TestCaseWithSimulator): - class CircuitTestModule(Elaboratable): - def __init__(self, method_definition): - self.method = Method( - i=[("foo1", 3), ("foo2", [("bar1", 4), ("bar2", 6)])], - o=[("foo1", 3), ("foo2", [("bar1", 4), ("bar2", 6)])], - ) - - self.method_definition = method_definition - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - def_method(m, self.method)(self.method_definition) - - return m - - def do_test_definition(self, definer): - with self.run_simulation(TestDefMethod.CircuitTestModule(definer)): - pass - - def test_fields_valid1(self): - def definition(arg): - return {"foo1": Signal(3), "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} - - self.do_test_definition(definition) - - def test_fields_valid2(self): - rec = Signal(from_method_layout([("bar1", 4), ("bar2", 6)])) - - def definition(arg): - return {"foo1": Signal(3), "foo2": rec} - - self.do_test_definition(definition) - - def test_fields_valid3(self): - def definition(arg): - return arg - - self.do_test_definition(definition) - - def test_fields_valid4(self): - def definition(arg: MethodStruct): - return arg - - self.do_test_definition(definition) - - def test_fields_valid5(self): - def definition(**arg): - return arg - - self.do_test_definition(definition) - - def test_fields_valid6(self): - def definition(foo1, foo2): - return {"foo1": foo1, "foo2": foo2} - - self.do_test_definition(definition) - - def test_fields_valid7(self): - def definition(foo1, **arg): - return {"foo1": foo1, "foo2": arg["foo2"]} - - self.do_test_definition(definition) - - def test_fields_invalid1(self): - def definition(arg): - return {"foo1": Signal(3), "baz": Signal(4)} - - with pytest.raises(KeyError): - self.do_test_definition(definition) - - def test_fields_invalid2(self): - def definition(arg): - return {"foo1": Signal(3)} - - with pytest.raises(KeyError): - self.do_test_definition(definition) - - def test_fields_invalid3(self): - def definition(arg): - return {"foo1": {"baz1": Signal(), "baz2": Signal()}, "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} - - with pytest.raises(TypeError): - self.do_test_definition(definition) - - def test_fields_invalid4(self): - def definition(arg: Value): - return arg - - with pytest.raises(TypeError): - self.do_test_definition(definition) - - def test_fields_invalid5(self): - def definition(foo): - return foo - - with pytest.raises(TypeError): - self.do_test_definition(definition) - - def test_fields_invalid6(self): - def definition(foo1): - return {"foo1": foo1, "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} - - with 
pytest.raises(TypeError): - self.do_test_definition(definition) - - -class TestDefMethods(TestCaseWithSimulator): - class CircuitTestModule(Elaboratable): - def __init__(self, method_definition): - self.methods = [ - Method( - i=[("foo", 3)], - o=[("foo", 3)], - ) - for _ in range(4) - ] - - self.method_definition = method_definition - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - def_methods(m, self.methods)(self.method_definition) - - return m - - def test_basic_methods(self): - def definition(idx: int, foo: Value): - return {"foo": foo + idx} - - circuit = SimpleTestCircuit(TestDefMethods.CircuitTestModule(definition)) - - def test_process(): - for k, method in enumerate(circuit.methods): - val = random.randrange(0, 2**3) - ret = yield from method.call(foo=val) - assert ret["foo"] == (val + k) % 2**3 - - with self.run_simulation(circuit) as sim: - sim.add_process(test_process) - - -class AdapterCircuit(Elaboratable): - def __init__(self, module, methods): - self.module = module - self.methods = methods - - def elaborate(self, platform): - m = TModule() - - m.submodules += self.module - for method in self.methods: - m.submodules += AdapterTrans(method) - - return m - - -class TestInvalidMethods(TestCase): - def assert_re(self, msg, m): - with pytest.raises(RuntimeError, match=msg): - Fragment.get(TransactionModule(m), platform=None) - - def test_twice(self): - class Twice(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - with self.meth1.body(m): - pass - - with self.meth2.body(m): - self.meth1(m) - self.meth1(m) - - return m - - self.assert_re("called twice", Twice()) - - def test_twice_cond(self): - class Twice(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - with self.meth1.body(m): - pass - - with self.meth2.body(m): - with m.If(1): - self.meth1(m) - with m.Else(): - self.meth1(m) - - return m - - Fragment.get(TransactionModule(Twice()), platform=None) - - def test_diamond(self): - class Diamond(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - self.meth3 = Method() - self.meth4 = Method() - - def elaborate(self, platform): - m = TModule() - - with self.meth1.body(m): - pass - - with self.meth2.body(m): - self.meth1(m) - - with self.meth3.body(m): - self.meth1(m) - - with self.meth4.body(m): - self.meth2(m) - self.meth3(m) - - return m - - m = Diamond() - self.assert_re("called twice", AdapterCircuit(m, [m.meth4])) - - def test_loop(self): - class Loop(Elaboratable): - def __init__(self): - self.meth1 = Method() - - def elaborate(self, platform): - m = TModule() - - with self.meth1.body(m): - self.meth1(m) - - return m - - m = Loop() - self.assert_re("called twice", AdapterCircuit(m, [m.meth1])) - - def test_cycle(self): - class Cycle(Elaboratable): - def __init__(self): - self.meth1 = Method() - self.meth2 = Method() - - def elaborate(self, platform): - m = TModule() - - with self.meth1.body(m): - self.meth2(m) - - with self.meth2.body(m): - self.meth1(m) - - return m - - m = Cycle() - self.assert_re("called twice", AdapterCircuit(m, [m.meth1])) - - def test_redefine(self): - class Redefine(Elaboratable): - def elaborate(self, platform): - m = TModule() - m._MustUse__silence = True # type: ignore - - meth = Method() - - with 
meth.body(m): - pass - - with meth.body(m): - pass - - self.assert_re("already defined", Redefine()) - - def test_undefined_in_trans(self): - class Undefined(Elaboratable): - def __init__(self): - self.meth = Method(i=data_layout(1)) - - def elaborate(self, platform): - return TModule() - - class Circuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - m.submodules.undefined = undefined = Undefined() - m.submodules.adapter = AdapterTrans(undefined.meth) - - return m - - self.assert_re("not defined", Circuit()) - - -WIDTH = 8 - - -class Quadruple(Elaboratable): - def __init__(self): - layout = data_layout(WIDTH) - self.id = Method(i=layout, o=layout) - self.double = Method(i=layout, o=layout) - self.quadruple = Method(i=layout, o=layout) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.id) - def _(arg): - return arg - - @def_method(m, self.double) - def _(arg): - return {"data": self.id(m, arg).data * 2} - - @def_method(m, self.quadruple) - def _(arg): - return {"data": self.double(m, arg).data * 2} - - return m - - -class QuadrupleCircuit(Elaboratable): - def __init__(self, quadruple): - self.quadruple = quadruple - - def elaborate(self, platform): - m = TModule() - - m.submodules.quadruple = self.quadruple - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(self.quadruple.quadruple)) - - return m - - -class Quadruple2(Elaboratable): - def __init__(self): - layout = data_layout(WIDTH) - self.quadruple = Method(i=layout, o=layout) - - def elaborate(self, platform): - m = TModule() - - m.submodules.sub = Quadruple() - - @def_method(m, self.quadruple) - def _(arg): - return {"data": 2 * m.submodules.sub.double(m, arg).data} - - return m - - -class TestQuadrupleCircuits(TestCaseWithSimulator): - @parameterized.expand([(Quadruple,), (Quadruple2,)]) - def test(self, quadruple): - circ = QuadrupleCircuit(quadruple()) - - def process(): - for n in range(1 << (WIDTH - 2)): - out = yield from circ.tb.call(data=n) - assert out["data"] == n * 4 - - with self.run_simulation(circ) as sim: - sim.add_process(process) - - -class ConditionalCallCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - meth = Method(i=data_layout(1)) - - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - m.submodules.out = self.out = TestbenchIO(Adapter()) - - @def_method(m, meth) - def _(arg): - with m.If(arg): - self.out.adapter.iface(m) - - return m - - -class ConditionalMethodCircuit1(Elaboratable): - def elaborate(self, platform): - m = TModule() - - meth = Method() - - self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - - @def_method(m, meth, ready=self.ready) - def _(arg): - pass - - return m - - -class ConditionalMethodCircuit2(Elaboratable): - def elaborate(self, platform): - m = TModule() - - meth = Method() - - self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) - - with m.If(self.ready): - - @def_method(m, meth) - def _(arg): - pass - - return m - - -class ConditionalTransactionCircuit1(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(Adapter()) - - with Transaction().body(m, request=self.ready): - self.tb.adapter.iface(m) - - return m - - -class ConditionalTransactionCircuit2(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - m.submodules.tb = self.tb = TestbenchIO(Adapter()) - - with m.If(self.ready): - with Transaction().body(m): - 
self.tb.adapter.iface(m) - - return m - - -class TestConditionals(TestCaseWithSimulator): - def test_conditional_call(self): - circ = ConditionalCallCircuit() - - def process(): - yield from circ.out.disable() - yield from circ.tb.call_init(data=0) - yield Settle() - assert not (yield from circ.out.done()) - assert not (yield from circ.tb.done()) - - yield from circ.out.enable() - yield Settle() - assert not (yield from circ.out.done()) - assert (yield from circ.tb.done()) - - yield from circ.tb.call_init(data=1) - yield Settle() - assert (yield from circ.out.done()) - assert (yield from circ.tb.done()) - - # the argument is still 1 but the method is not called - yield from circ.tb.disable() - yield Settle() - assert not (yield from circ.out.done()) - assert not (yield from circ.tb.done()) - - with self.run_simulation(circ) as sim: - sim.add_process(process) - - @parameterized.expand( - [ - (ConditionalMethodCircuit1,), - (ConditionalMethodCircuit2,), - (ConditionalTransactionCircuit1,), - (ConditionalTransactionCircuit2,), - ] - ) - def test_conditional(self, elaboratable): - circ = elaboratable() - - def process(): - yield from circ.tb.enable() - yield circ.ready.eq(0) - yield Settle() - assert not (yield from circ.tb.done()) - - yield circ.ready.eq(1) - yield Settle() - assert (yield from circ.tb.done()) - - with self.run_simulation(circ) as sim: - sim.add_process(process) - - -class NonexclusiveMethodCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - self.running = Signal() - self.data = Signal(WIDTH) - - method = Method(o=data_layout(WIDTH), nonexclusive=True) - - @def_method(m, method, self.ready) - def _(): - m.d.comb += self.running.eq(1) - return {"data": self.data} - - m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method)) - m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method)) - - return m - - -class TestNonexclusiveMethod(TestCaseWithSimulator): - def test_nonexclusive_method(self): - circ = NonexclusiveMethodCircuit() - - def process(): - for x in range(8): - t1en = bool(x & 1) - t2en = bool(x & 2) - mrdy = bool(x & 4) - - if t1en: - yield from circ.t1.enable() - else: - yield from circ.t1.disable() - - if t2en: - yield from circ.t2.enable() - else: - yield from circ.t2.disable() - - if mrdy: - yield circ.ready.eq(1) - else: - yield circ.ready.eq(0) - - yield circ.data.eq(x) - yield Settle() - - assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) - assert bool((yield from circ.t1.done())) == (t1en and mrdy) - assert bool((yield from circ.t2.done())) == (t2en and mrdy) - - if t1en and mrdy: - assert (yield from circ.t1.get_outputs()) == {"data": x} - - if t2en and mrdy: - assert (yield from circ.t2.get_outputs()) == {"data": x} - - with self.run_simulation(circ) as sim: - sim.add_process(process) - - -class TwoNonexclusiveConflictCircuit(Elaboratable): - def __init__(self, two_nonexclusive: bool): - self.two_nonexclusive = two_nonexclusive - - def elaborate(self, platform): - m = TModule() - - self.running1 = Signal() - self.running2 = Signal() - - method1 = Method(o=data_layout(WIDTH), nonexclusive=True) - method2 = Method(o=data_layout(WIDTH), nonexclusive=self.two_nonexclusive) - method_in = Method(o=data_layout(WIDTH)) - - @def_method(m, method_in) - def _(): - return {"data": 0} - - @def_method(m, method1) - def _(): - m.d.comb += self.running1.eq(1) - return method_in(m) - - @def_method(m, method2) - def _(): - m.d.comb += self.running2.eq(1) - return method_in(m) - - m.submodules.t1 = self.t1 = 
TestbenchIO(AdapterTrans(method1)) - m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method2)) - - return m - - -class TestConflicting(TestCaseWithSimulator): - @pytest.mark.parametrize( - "test_circuit", [lambda: TwoNonexclusiveConflictCircuit(False), lambda: TwoNonexclusiveConflictCircuit(True)] - ) - def test_conflicting(self, test_circuit: Callable[[], TwoNonexclusiveConflictCircuit]): - circ = test_circuit() - - def process(): - yield from circ.t1.enable() - yield from circ.t2.enable() - yield Settle() - - assert not (yield circ.running1) or not (yield circ.running2) - - with self.run_simulation(circ) as sim: - sim.add_process(process) - - -class CustomCombinerMethodCircuit(Elaboratable): - def elaborate(self, platform): - m = TModule() - - self.ready = Signal() - self.running = Signal() - - def combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: - result = C(0) - for i, v in enumerate(args): - result = result ^ Mux(runs[i], v.data, 0) - return {"data": result} - - method = Method(i=data_layout(WIDTH), o=data_layout(WIDTH), nonexclusive=True, combiner=combiner) - - @def_method(m, method, self.ready) - def _(data: Value): - m.d.comb += self.running.eq(1) - return {"data": data} - - m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method)) - m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method)) - - return m - - -class TestCustomCombinerMethod(TestCaseWithSimulator): - def test_custom_combiner_method(self): - circ = CustomCombinerMethodCircuit() - - def process(): - for x in range(8): - t1en = bool(x & 1) - t2en = bool(x & 2) - mrdy = bool(x & 4) - - val1 = random.randrange(0, 2**WIDTH) - val2 = random.randrange(0, 2**WIDTH) - val1e = val1 if t1en else 0 - val2e = val2 if t2en else 0 - - yield from circ.t1.call_init(data=val1) - yield from circ.t2.call_init(data=val2) - - if t1en: - yield from circ.t1.enable() - else: - yield from circ.t1.disable() - - if t2en: - yield from circ.t2.enable() - else: - yield from circ.t2.disable() - - if mrdy: - yield circ.ready.eq(1) - else: - yield circ.ready.eq(0) - - yield Settle() - - assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) - assert bool((yield from circ.t1.done())) == (t1en and mrdy) - assert bool((yield from circ.t2.done())) == (t2en and mrdy) - - if t1en and mrdy: - assert (yield from circ.t1.get_outputs()) == {"data": val1e ^ val2e} - - if t2en and mrdy: - assert (yield from circ.t2.get_outputs()) == {"data": val1e ^ val2e} - - with self.run_simulation(circ) as sim: - sim.add_process(process) - - -class DataDependentConditionalCircuit(Elaboratable): - def __init__(self, n=2, ready_function=lambda arg: arg.data != 3): - self.method = Method(i=data_layout(n)) - self.ready_function = ready_function - - self.in_t1 = Signal(from_method_layout(data_layout(n))) - self.in_t2 = Signal(from_method_layout(data_layout(n))) - self.ready = Signal() - self.req_t1 = Signal() - self.req_t2 = Signal() - - self.out_m = Signal() - self.out_t1 = Signal() - self.out_t2 = Signal() - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method, self.ready, validate_arguments=self.ready_function) - def _(data): - m.d.comb += self.out_m.eq(1) - - with Transaction().body(m, request=self.req_t1): - m.d.comb += self.out_t1.eq(1) - self.method(m, self.in_t1) - - with Transaction().body(m, request=self.req_t2): - m.d.comb += self.out_t2.eq(1) - self.method(m, self.in_t2) - - return m - - -class TestDataDependentConditionalMethod(TestCaseWithSimulator): - def setup_method(self): - 
self.test_number = 200 - self.bad_number = 3 - self.n = 2 - - def base_random(self, f): - random.seed(14) - self.circ = DataDependentConditionalCircuit(n=self.n, ready_function=f) - - def process(): - for _ in range(self.test_number): - in1 = random.randrange(0, 2**self.n) - in2 = random.randrange(0, 2**self.n) - m_ready = random.randrange(2) - req_t1 = random.randrange(2) - req_t2 = random.randrange(2) - - yield self.circ.in_t1.eq(in1) - yield self.circ.in_t2.eq(in2) - yield self.circ.req_t1.eq(req_t1) - yield self.circ.req_t2.eq(req_t2) - yield self.circ.ready.eq(m_ready) - yield Settle() - - out_m = yield self.circ.out_m - out_t1 = yield self.circ.out_t1 - out_t2 = yield self.circ.out_t2 - - if not m_ready or (not req_t1 or in1 == self.bad_number) and (not req_t2 or in2 == self.bad_number): - assert out_m == 0 - assert out_t1 == 0 - assert out_t2 == 0 - continue - # Here the method's global ready signal is high and we requested at least one of the transactions - # we also know that at least one of the requested transactions has correct input data - - assert out_m == 1 - assert out_t1 ^ out_t2 == 1 - # inX == self.bad_number implies out_tX==0 - assert in1 != self.bad_number or not out_t1 - assert in2 != self.bad_number or not out_t2 - - yield Tick() - - with self.run_simulation(self.circ, 100) as sim: - sim.add_process(process) - - def test_random_arg(self): - self.base_random(lambda arg: arg.data != self.bad_number) - - def test_random_kwarg(self): - self.base_random(lambda data: data != self.bad_number) diff --git a/test/transactron/test_metrics.py b/test/transactron/test_metrics.py deleted file mode 100644 index c7a0f0765..000000000 --- a/test/transactron/test_metrics.py +++ /dev/null @@ -1,547 +0,0 @@ -import json -import random -import queue -from typing import Type -from enum import IntFlag, IntEnum, auto, Enum - -from parameterized import parameterized_class - -from amaranth import * -from amaranth.sim import Settle, Tick - -from transactron.lib.metrics import * -from transactron import * -from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit -from transactron.testing.infrastructure import Now -from transactron.utils.dependencies import DependencyContext - - -class CounterInMethodCircuit(Elaboratable): - def __init__(self): - self.method = Method() - self.counter = HwCounter("in_method") - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - @def_method(m, self.method) - def _(): - self.counter.incr(m) - - return m - - -class CounterWithConditionInMethodCircuit(Elaboratable): - def __init__(self): - self.method = Method(i=[("cond", 1)]) - self.counter = HwCounter("with_condition_in_method") - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - @def_method(m, self.method) - def _(cond): - self.counter.incr(m, cond=cond) - - return m - - -class CounterWithoutMethodCircuit(Elaboratable): - def __init__(self): - self.cond = Signal() - self.counter = HwCounter("with_condition_without_method") - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - with Transaction().body(m): - self.counter.incr(m, cond=self.cond) - - return m - - -class TestHwCounter(TestCaseWithSimulator): - def setup_method(self) -> None: - random.seed(42) - - def test_counter_in_method(self): - m = SimpleTestCircuit(CounterInMethodCircuit()) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - def test_process(): - called_cnt = 0 - for _ in range(200): - call_now = 
random.randint(0, 1) == 0 - - if call_now: - yield from m.method.call() - else: - yield Tick() - - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m._dut.counter.count.value) - - if call_now: - called_cnt += 1 - - with self.run_simulation(m) as sim: - sim.add_process(test_process) - - def test_counter_with_condition_in_method(self): - m = SimpleTestCircuit(CounterWithConditionInMethodCircuit()) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - def test_process(): - called_cnt = 0 - for _ in range(200): - call_now = random.randint(0, 1) == 0 - condition = random.randint(0, 1) - - if call_now: - yield from m.method.call(cond=condition) - else: - yield Tick() - - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m._dut.counter.count.value) - - if call_now and condition == 1: - called_cnt += 1 - - with self.run_simulation(m) as sim: - sim.add_process(test_process) - - def test_counter_with_condition_without_method(self): - m = CounterWithoutMethodCircuit() - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - def test_process(): - called_cnt = 0 - for _ in range(200): - condition = random.randint(0, 1) - - yield m.cond.eq(condition) - yield Tick() - - # Note that it takes one cycle to update the register value, so here - # we are comparing the "previous" values. - assert called_cnt == (yield m.counter.count.value) - - if condition == 1: - called_cnt += 1 - - with self.run_simulation(m) as sim: - sim.add_process(test_process) - - -class OneHotEnum(IntFlag): - ADD = auto() - XOR = auto() - OR = auto() - - -class PlainIntEnum(IntEnum): - TEST_1 = auto() - TEST_2 = auto() - TEST_3 = auto() - - -class TaggedCounterCircuit(Elaboratable): - def __init__(self, tags: range | Type[Enum] | list[int]): - self.counter = TaggedCounter("counter", "", tags=tags) - - self.cond = Signal() - self.tag = Signal(self.counter.tag_width) - - def elaborate(self, platform): - m = TModule() - - m.submodules.counter = self.counter - - with Transaction().body(m): - self.counter.incr(m, self.tag, cond=self.cond) - - return m - - -class TestTaggedCounter(TestCaseWithSimulator): - def setup_method(self) -> None: - random.seed(42) - - def do_test_enum(self, tags: range | Type[Enum] | list[int], tag_values: list[int]): - m = TaggedCounterCircuit(tags) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - counts: dict[int, int] = {} - for i in tag_values: - counts[i] = 0 - - def test_process(): - for _ in range(200): - for i in tag_values: - assert counts[i] == (yield m.counter.counters[i].value) - - tag = random.choice(list(tag_values)) - - yield m.cond.eq(1) - yield m.tag.eq(tag) - yield Tick() - yield m.cond.eq(0) - yield Tick() - - counts[tag] += 1 - - with self.run_simulation(m) as sim: - sim.add_process(test_process) - - def test_one_hot_enum(self): - self.do_test_enum(OneHotEnum, [e.value for e in OneHotEnum]) - - def test_plain_int_enum(self): - self.do_test_enum(PlainIntEnum, [e.value for e in PlainIntEnum]) - - def test_negative_range(self): - r = range(-10, 15, 3) - self.do_test_enum(r, list(r)) - - def test_positive_range(self): - r = range(0, 30, 2) - self.do_test_enum(r, list(r)) - - def test_value_list(self): - values = [-2137, 2, 4, 8, 42] - self.do_test_enum(values, values) - - -class ExpHistogramCircuit(Elaboratable): - def __init__(self, bucket_cnt: int, 
sample_width: int): - self.sample_width = sample_width - - self.method = Method(i=data_layout(32)) - self.histogram = HwExpHistogram("histogram", bucket_count=bucket_cnt, sample_width=sample_width) - - def elaborate(self, platform): - m = TModule() - - m.submodules.histogram = self.histogram - - @def_method(m, self.method) - def _(data): - self.histogram.add(m, data[0 : self.sample_width]) - - return m - - -@parameterized_class( - ("bucket_count", "sample_width"), - [ - (5, 5), # last bucket is [8, inf), max sample=31 - (8, 5), # last bucket is [64, inf), max sample=31 - (8, 6), # last bucket is [64, inf), max sample=63 - (8, 20), # last bucket is [64, inf), max sample=big - ], -) -class TestHwHistogram(TestCaseWithSimulator): - bucket_count: int - sample_width: int - - def test_histogram(self): - random.seed(42) - - m = SimpleTestCircuit(ExpHistogramCircuit(bucket_cnt=self.bucket_count, sample_width=self.sample_width)) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - max_sample_value = 2**self.sample_width - 1 - - def test_process(): - min = max_sample_value + 1 - max = 0 - sum = 0 - count = 0 - - buckets = [0] * self.bucket_count - - for _ in range(500): - if random.randrange(3) == 0: - value = random.randint(0, max_sample_value) - if value < min: - min = value - if value > max: - max = value - sum += value - count += 1 - for i in range(self.bucket_count): - if value < 2**i or i == self.bucket_count - 1: - buckets[i] += 1 - break - yield from m.method.call(data=value) - yield Tick() - else: - yield Tick() - - histogram = m._dut.histogram - # Skip the assertion if the min is still uninitialized - if min != max_sample_value + 1: - assert min == (yield histogram.min.value) - - assert max == (yield histogram.max.value) - assert sum == (yield histogram.sum.value) - assert count == (yield histogram.count.value) - - total_count = 0 - for i in range(self.bucket_count): - bucket_value = yield histogram.buckets[i].value - total_count += bucket_value - assert buckets[i] == bucket_value - - # Sanity check if all buckets sum up to the total count value - assert total_count == (yield histogram.count.value) - - with self.run_simulation(m) as sim: - sim.add_process(test_process) - - -class TestLatencyMeasurerBase(TestCaseWithSimulator): - def check_latencies(self, m: SimpleTestCircuit, latencies: list[int]): - assert min(latencies) == (yield m._dut.histogram.min.value) - assert max(latencies) == (yield m._dut.histogram.max.value) - assert sum(latencies) == (yield m._dut.histogram.sum.value) - assert len(latencies) == (yield m._dut.histogram.count.value) - - for i in range(m._dut.histogram.bucket_count): - bucket_start = 0 if i == 0 else 2 ** (i - 1) - bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i - - count = sum(1 for x in latencies if bucket_start <= x < bucket_end) - assert count == (yield m._dut.histogram.buckets[i].value) - - -@parameterized_class( - ("slots_number", "expected_consumer_wait"), - [ - (2, 5), - (2, 10), - (5, 10), - (10, 1), - (10, 10), - (5, 5), - ], -) -class TestFIFOLatencyMeasurer(TestLatencyMeasurerBase): - slots_number: int - expected_consumer_wait: float - - def test_latency_measurer(self): - random.seed(42) - - m = SimpleTestCircuit(FIFOLatencyMeasurer("latency", slots_number=self.slots_number, max_latency=300)) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - latencies: list[int] = [] - - event_queue = queue.Queue() - - finish = False - - def producer(): - nonlocal finish - - for _ in range(200): - 
yield from m._start.call() - - # Make sure that the time is updated first. - yield Settle() - time = yield Now() - event_queue.put(time) - yield from self.random_wait_geom(0.8) - - finish = True - - def consumer(): - while not finish: - yield from m._stop.call() - - # Make sure that the time is updated first. - yield Settle() - time = yield Now() - latencies.append(time - event_queue.get()) - - yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) - - self.check_latencies(m, latencies) - - with self.run_simulation(m) as sim: - sim.add_process(producer) - sim.add_process(consumer) - - -@parameterized_class( - ("slots_number", "expected_consumer_wait"), - [ - (2, 5), - (2, 10), - (5, 10), - (10, 1), - (10, 10), - (5, 5), - ], -) -class TestIndexedLatencyMeasurer(TestLatencyMeasurerBase): - slots_number: int - expected_consumer_wait: float - - def test_latency_measurer(self): - random.seed(42) - - m = SimpleTestCircuit(TaggedLatencyMeasurer("latency", slots_number=self.slots_number, max_latency=300)) - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - latencies: list[int] = [] - - events = list(0 for _ in range(self.slots_number)) - free_slots = list(k for k in range(self.slots_number)) - used_slots: list[int] = [] - - finish = False - - def producer(): - nonlocal finish - - for _ in range(200): - while not free_slots: - yield Tick() - continue - yield Settle() - - slot_id = random.choice(free_slots) - yield from m._start.call(slot=slot_id) - - time = yield Now() - - events[slot_id] = time - free_slots.remove(slot_id) - used_slots.append(slot_id) - - yield from self.random_wait_geom(0.8) - - finish = True - - def consumer(): - while not finish: - while not used_slots: - yield Tick() - continue - - slot_id = random.choice(used_slots) - - yield from m._stop.call(slot=slot_id) - - time = yield Now() - - yield Settle() - yield Settle() - - latencies.append(time - events[slot_id]) - used_slots.remove(slot_id) - free_slots.append(slot_id) - - yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) - - self.check_latencies(m, latencies) - - with self.run_simulation(m) as sim: - sim.add_process(producer) - sim.add_process(consumer) - - -class MetricManagerTestCircuit(Elaboratable): - def __init__(self): - self.incr_counters = Method(i=[("counter1", 1), ("counter2", 1), ("counter3", 1)]) - - self.counter1 = HwCounter("foo.counter1", "this is the description") - self.counter2 = HwCounter("bar.baz.counter2") - self.counter3 = HwCounter("bar.baz.counter3", "yet another description") - - def elaborate(self, platform): - m = TModule() - - m.submodules += [self.counter1, self.counter2, self.counter3] - - @def_method(m, self.incr_counters) - def _(counter1, counter2, counter3): - self.counter1.incr(m, cond=counter1) - self.counter2.incr(m, cond=counter2) - self.counter3.incr(m, cond=counter3) - - return m - - -class TestMetricsManager(TestCaseWithSimulator): - def test_metrics_metadata(self): - # We need to initialize the circuit to make sure that metrics are registered - # in the dependency manager. - m = MetricManagerTestCircuit() - metrics_manager = HardwareMetricsManager() - - # Run the simulation so Amaranth doesn't scream that we have unused elaboratables. 
- with self.run_simulation(m): - pass - - assert metrics_manager.get_metrics()["foo.counter1"].to_json() == json.dumps( # type: ignore - { - "fully_qualified_name": "foo.counter1", - "description": "this is the description", - "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, - } - ) - - assert metrics_manager.get_metrics()["bar.baz.counter2"].to_json() == json.dumps( # type: ignore - { - "fully_qualified_name": "bar.baz.counter2", - "description": "", - "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, - } - ) - - assert metrics_manager.get_metrics()["bar.baz.counter3"].to_json() == json.dumps( # type: ignore - { - "fully_qualified_name": "bar.baz.counter3", - "description": "yet another description", - "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, - } - ) - - def test_returned_reg_values(self): - random.seed(42) - - m = SimpleTestCircuit(MetricManagerTestCircuit()) - metrics_manager = HardwareMetricsManager() - - DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) - - def test_process(): - counters = [0] * 3 - for _ in range(200): - rand = [random.randint(0, 1) for _ in range(3)] - - yield from m.incr_counters.call(counter1=rand[0], counter2=rand[1], counter3=rand[2]) - yield Tick() - - for i in range(3): - if rand[i] == 1: - counters[i] += 1 - - assert counters[0] == (yield metrics_manager.get_register_value("foo.counter1", "count")) - assert counters[1] == (yield metrics_manager.get_register_value("bar.baz.counter2", "count")) - assert counters[2] == (yield metrics_manager.get_register_value("bar.baz.counter3", "count")) - - with self.run_simulation(m) as sim: - sim.add_process(test_process)
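For reference, the counter pattern exercised by the deleted metrics tests survives in the standalone Transactron package. Below is a minimal sketch of that pattern, assuming the package keeps the HwCounter, Method and def_method names used in the file above; it is illustrative only, not the package's documented API:

    from amaranth import Elaboratable
    from transactron import TModule, Method, def_method
    from transactron.lib.metrics import HwCounter

    class StepCounter(Elaboratable):
        """Hypothetical example module: counts successful calls of its only method."""

        def __init__(self):
            self.step = Method()
            self.counter = HwCounter("example.steps", "number of step() calls")

        def elaborate(self, platform):
            m = TModule()
            m.submodules.counter = self.counter

            @def_method(m, self.step)
            def _():
                # Incremented once per cycle in which the method runs.
                self.counter.incr(m)

            return m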
diff --git a/test/transactron/test_simultaneous.py b/test/transactron/test_simultaneous.py deleted file mode 100644 index d0859301d..000000000 --- a/test/transactron/test_simultaneous.py +++ /dev/null @@ -1,172 +0,0 @@ -import pytest -from itertools import product -from typing import Optional -from amaranth import * -from amaranth.sim import * - -from transactron.utils import ModuleConnector - -from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock - -from transactron import * -from transactron.lib import Adapter, Connect, ConnectTrans - - -def empty_method(m: TModule, method: Method): - @def_method(m, method) - def _(): - pass - - -class SimultaneousDiamondTestCircuit(Elaboratable): - def __init__(self): - self.method_l = Method() - self.method_r = Method() - self.method_u = Method() - self.method_d = Method() - - def elaborate(self, platform): - m = TModule() - - empty_method(m, self.method_l) - empty_method(m, self.method_r) - empty_method(m, self.method_u) - empty_method(m, self.method_d) - - # the only possibilities for the following are: (l, u, r) or (l, d, r) - self.method_l.simultaneous_alternatives(self.method_u, self.method_d) - self.method_r.simultaneous_alternatives(self.method_u, self.method_d) - - return m - - -class TestSimultaneousDiamond(TestCaseWithSimulator): - def test_diamond(self): - circ = SimpleTestCircuit(SimultaneousDiamondTestCircuit()) - - def process(): - methods = {"l": circ.method_l, "r": circ.method_r, "u": circ.method_u, "d": circ.method_d} - for i in range(1 << len(methods)): - enables: dict[str, bool] = {} - for k, n in enumerate(methods): - enables[n] = bool(i & (1 << k)) - yield from methods[n].set_enable(enables[n]) - yield Tick() - dones: dict[str, bool] = {} - for n in methods: - dones[n] = bool((yield from methods[n].done())) - for n in methods: - if not enables[n]: - assert not dones[n] - if enables["l"] and enables["r"] and (enables["u"] or enables["d"]): - assert dones["l"] - assert dones["r"] - assert dones["u"] or dones["d"] - else: - assert not any(dones.values()) - - with self.run_simulation(circ) as sim: - sim.add_process(process) - - -class UnsatisfiableTriangleTestCircuit(Elaboratable): - def __init__(self): - self.method_l = Method() - self.method_u = Method() - self.method_d = Method() - - def elaborate(self, platform): - m = TModule() - - empty_method(m, self.method_l) - empty_method(m, self.method_u) - empty_method(m, self.method_d) - - # the following is unsatisfiable - self.method_l.simultaneous_alternatives(self.method_u, self.method_d) - self.method_u.simultaneous(self.method_d) - - return m - - -class TestUnsatisfiableTriangle(TestCaseWithSimulator): - def test_unsatisfiable(self): - circ = SimpleTestCircuit(UnsatisfiableTriangleTestCircuit()) - - with pytest.raises(RuntimeError): - with self.run_simulation(circ) as _: - pass - - -class HelperConnect(Elaboratable): - def __init__(self, source: Method, target: Method, request: Signal, data: int): - self.source = source - self.target = target - self.request = request - self.data = data - - def elaborate(self, platform): - m = TModule() - - with Transaction().body(m, request=self.request): - self.target(m, self.data ^ self.source(m).data) - - return m - - -class TransitivityTestCircuit(Elaboratable): - def __init__(self, target: Method, req1: Signal, req2: Signal): - self.source1 = Method(i=[("data", 2)]) - self.source2 = Method(i=[("data", 2)]) - self.target = target - self.req1 = req1 - self.req2 = req2 - - def elaborate(self, platform): - m = TModule() - - m.submodules.c1 = c1 = Connect([("data", 2)]) - m.submodules.c2 = c2 = Connect([("data", 2)]) - self.source1.proxy(m, c1.write) - self.source2.proxy(m, c1.write) - m.submodules.ct = ConnectTrans(c2.read, self.target) - m.submodules.hc1 = HelperConnect(c1.read, c2.write, self.req1, 1) - m.submodules.hc2 = HelperConnect(c1.read, c2.write, self.req2, 2) - - return m - - -class TestTransitivity(TestCaseWithSimulator): - def test_transitivity(self): - target = TestbenchIO(Adapter(i=[("data", 2)])) - req1 = Signal() - req2 = Signal() - - circ = SimpleTestCircuit(TransitivityTestCircuit(target.adapter.iface, req1, req2)) - m = ModuleConnector(test_circuit=circ, target=target) - - result: Optional[int] - - @def_method_mock(lambda: target) - def target_process(data): - nonlocal result - result = data - - def process(): - nonlocal result - for source, data, reqv1, reqv2 in product([circ.source1, circ.source2], [0, 1, 2, 3], [0, 1], [0, 1]): - result = None - yield req1.eq(reqv1) - yield req2.eq(reqv2) - call_result = yield from source.call_try(data=data) - - if not reqv1 and not reqv2: - assert call_result is None - assert result is None - else: - assert call_result is not None - possibles = reqv1 * [data ^ 1] + reqv2 * [data ^ 2] - assert result in possibles - - with self.run_simulation(m) as sim: - sim.add_process(process) diff --git a/test/transactron/test_transactron_lib_storage.py b/test/transactron/test_transactron_lib_storage.py deleted file mode 100644 index 404c14a2d..000000000 --- a/test/transactron/test_transactron_lib_storage.py +++ /dev/null @@ -1,135 +0,0 @@ -from datetime import timedelta -from hypothesis import given, settings, Phase -from transactron.testing import * -from transactron.lib.storage import 
ContentAddressableMemory - - -class TestContentAddressableMemory(TestCaseWithSimulator): - addr_width = 4 - content_width = 5 - test_number = 30 - nop_number = 3 - addr_layout = data_layout(addr_width) - content_layout = data_layout(content_width) - - def setUp(self): - self.entries_count = 8 - - self.circ = SimpleTestCircuit( - ContentAddressableMemory(self.addr_layout, self.content_layout, self.entries_count) - ) - - self.memory = {} - - def generic_process( - self, - method, - input_lst, - behaviour_check=None, - state_change=None, - input_verification=None, - settle_count=0, - name="", - ): - def f(): - while input_lst: - # wait till all processes have finished the previous cycle - yield from self.multi_settle(4) - elem = input_lst.pop() - if isinstance(elem, OpNOP): - yield Tick() - continue - if input_verification is not None and not input_verification(elem): - yield Tick() - continue - response = yield from method.call(**elem) - yield from self.multi_settle(settle_count) - if behaviour_check is not None: - # Here accesses to the circuit are allowed - ret = behaviour_check(elem, response) - if isinstance(ret, Generator): - yield from ret - if state_change is not None: - # This is a plain Python function on purpose, so that the circuit cannot be accessed - state_change(elem, response) - yield Tick() - - return f - - def push_process(self, in_push): - def verify_in(elem): - return not (frozenset(elem["addr"].items()) in self.memory) - - def modify_state(elem, response): - self.memory[frozenset(elem["addr"].items())] = elem["data"] - - return self.generic_process( - self.circ.push, - in_push, - state_change=modify_state, - input_verification=verify_in, - settle_count=3, - name="push", - ) - - def read_process(self, in_read): - def check(elem, response): - addr = elem["addr"] - frozen_addr = frozenset(addr.items()) - if frozen_addr in self.memory: - assert response["not_found"] == 0 - assert response["data"] == self.memory[frozen_addr] - else: - assert response["not_found"] == 1 - - return self.generic_process(self.circ.read, in_read, behaviour_check=check, settle_count=0, name="read") - - def remove_process(self, in_remove): - def modify_state(elem, response): - if frozenset(elem["addr"].items()) in self.memory: - del self.memory[frozenset(elem["addr"].items())] - - return self.generic_process(self.circ.remove, in_remove, state_change=modify_state, settle_count=2, name="remv") - - def write_process(self, in_write): - def verify_in(elem): - ret = frozenset(elem["addr"].items()) in self.memory - return ret - - def check(elem, response): - assert response["not_found"] == int(frozenset(elem["addr"].items()) not in self.memory) - - def modify_state(elem, response): - if frozenset(elem["addr"].items()) in self.memory: - self.memory[frozenset(elem["addr"].items())] = elem["data"] - - return self.generic_process( - self.circ.write, - in_write, - behaviour_check=check, - state_change=modify_state, - input_verification=None, - settle_count=1, - name="writ", - ) - - @settings( - max_examples=10, - phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.shrink), - derandomize=True, - deadline=timedelta(milliseconds=500), - ) - @given( - generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]), - generate_process_input(test_number, nop_number, [("addr", addr_layout), ("data", content_layout)]), - generate_process_input(test_number, nop_number, [("addr", addr_layout)]), - generate_process_input(test_number, nop_number, [("addr", addr_layout)]), - ) - def test_random(self, in_push, in_write, in_read, in_remove): - with self.reinitialize_fixtures(): - self.setUp() - with self.run_simulation(self.circ, max_cycles=500) as sim: - sim.add_process(self.push_process(in_push)) - sim.add_process(self.read_process(in_read)) - sim.add_process(self.write_process(in_write)) - sim.add_process(self.remove_process(in_remove))
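The hypothesis-driven test above was the only in-tree coverage of ContentAddressableMemory. For orientation, here is a minimal sketch of driving it directly, reusing the method and field names the deleted test exercises (push, read, addr, data, not_found); this is a hedged example assuming the Transactron package keeps this interface, not a replacement for the removed test:

    from amaranth.sim import Tick
    from transactron.lib.storage import ContentAddressableMemory
    from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, data_layout

    class TestCamSmoke(TestCaseWithSimulator):
        def test_push_then_read(self):
            # 4-bit keys, 5-bit payloads and 8 entries, as in the test above.
            circ = SimpleTestCircuit(ContentAddressableMemory(data_layout(4), data_layout(5), 8))

            def process():
                # Insert one entry, then look it up by its address.
                yield from circ.push.call(addr={"data": 3}, data={"data": 21})
                yield Tick()  # give the write a cycle to settle, mirroring the settle counts above
                resp = yield from circ.read.call(addr={"data": 3})
                assert resp["not_found"] == 0
                assert resp["data"] == {"data": 21}

            with self.run_simulation(circ) as sim:
                sim.add_process(process)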
in_push, in_write, in_read, in_remove): - with self.reinitialize_fixtures(): - self.setUp() - with self.run_simulation(self.circ, max_cycles=500) as sim: - sim.add_process(self.push_process(in_push)) - sim.add_process(self.read_process(in_read)) - sim.add_process(self.write_process(in_write)) - sim.add_process(self.remove_process(in_remove)) diff --git a/test/transactron/testing/__init__.py b/test/transactron/testing/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/testing/test_infrastructure.py b/test/transactron/testing/test_infrastructure.py deleted file mode 100644 index cfd59ec87..000000000 --- a/test/transactron/testing/test_infrastructure.py +++ /dev/null @@ -1,31 +0,0 @@ -from amaranth import * -from transactron.testing import * - - -class EmptyCircuit(Elaboratable): - def __init__(self): - pass - - def elaborate(self, platform): - m = Module() - return m - - -class TestNow(TestCaseWithSimulator): - def setup_method(self): - self.test_cycles = 10 - self.m = SimpleTestCircuit(EmptyCircuit()) - - def process(self): - for k in range(self.test_cycles): - now = yield Now() - assert k == now - # check if second call don't change the returned value - now = yield Now() - assert k == now - - yield Tick() - - def test_random(self): - with self.run_simulation(self.m, 50) as sim: - sim.add_process(self.process) diff --git a/test/transactron/testing/test_log.py b/test/transactron/testing/test_log.py deleted file mode 100644 index 6e6711d8e..000000000 --- a/test/transactron/testing/test_log.py +++ /dev/null @@ -1,124 +0,0 @@ -import pytest -import re -from amaranth import * -from amaranth.sim import Tick - -from transactron import * -from transactron.testing import TestCaseWithSimulator -from transactron.lib import logging - -LOGGER_NAME = "test_logger" - -log = logging.HardwareLogger(LOGGER_NAME) - - -class LogTest(Elaboratable): - def __init__(self): - self.input = Signal(range(100)) - self.counter = Signal(range(200)) - - def elaborate(self, platform): - m = TModule() - - with m.If(self.input == 42): - log.warning(m, True, "Log triggered under Amaranth If value+3=0x{:x}", self.input + 3) - - log.warning(m, self.input[0] == 0, "Input is even! input={}, counter={}", self.input, self.counter) - - m.d.sync += self.counter.eq(self.counter + 1) - - return m - - -class ErrorLogTest(Elaboratable): - def __init__(self): - self.input = Signal() - self.output = Signal() - - def elaborate(self, platform): - m = TModule() - - m.d.comb += self.output.eq(self.input & ~self.input) - - log.error( - m, - self.input != self.output, - "Input is different than output! 
input=0x{:x} output=0x{:x}", - self.input, - self.output, - ) - - return m - - -class AssertionTest(Elaboratable): - def __init__(self): - self.input = Signal() - self.output = Signal() - - def elaborate(self, platform): - m = TModule() - - m.d.comb += self.output.eq(self.input & ~self.input) - - log.assertion(m, self.input == self.output, "Output differs") - - return m - - -class TestLog(TestCaseWithSimulator): - def test_log(self, caplog): - m = LogTest() - - def proc(): - for i in range(50): - yield Tick() - yield m.input.eq(i) - - with self.run_simulation(m) as sim: - sim.add_process(proc) - - assert re.search( - r"WARNING test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " - + r"Log triggered under Amaranth If value\+3=0x2d", - caplog.text, - ) - for i in range(0, 50, 2): - assert re.search( - r"WARNING test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " - + f"Input is even! input={i}, counter={i + 1}", - caplog.text, - ) - - def test_error_log(self, caplog): - m = ErrorLogTest() - - def proc(): - yield Tick() - yield m.input.eq(1) - - with pytest.raises(AssertionError): - with self.run_simulation(m) as sim: - sim.add_process(proc) - - assert re.search( - r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] " - + "Input is different than output! input=0x1 output=0x0", - caplog.text, - ) - - def test_assertion(self, caplog): - m = AssertionTest() - - def proc(): - yield Tick() - yield m.input.eq(1) - - with pytest.raises(AssertionError): - with self.run_simulation(m) as sim: - sim.add_process(proc) - - assert re.search( - r"ERROR test_logger:logging\.py:\d+ \[test/transactron/testing/test_log\.py:\d+\] Output differs", - caplog.text, - ) diff --git a/test/transactron/testing/test_validate_arguments.py b/test/transactron/testing/test_validate_arguments.py deleted file mode 100644 index 18066ff5d..000000000 --- a/test/transactron/testing/test_validate_arguments.py +++ /dev/null @@ -1,59 +0,0 @@ -import random -from amaranth import * -from amaranth.sim import * - -from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout - -from transactron import * -from transactron.testing.sugar import def_method_mock -from transactron.lib import * - - -class ValidateArgumentsTestCircuit(Elaboratable): - def elaborate(self, platform): - m = Module() - - self.method = TestbenchIO(Adapter(i=data_layout(1), o=data_layout(1)).set(with_validate_arguments=True)) - self.caller1 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) - self.caller2 = TestbenchIO(AdapterTrans(self.method.adapter.iface)) - - m.submodules += [self.method, self.caller1, self.caller2] - - return m - - -class TestValidateArguments(TestCaseWithSimulator): - def control_caller(self, caller: TestbenchIO, method: TestbenchIO): - def process(): - for _ in range(100): - val = random.randrange(2) - pre_accepted_val = self.accepted_val - ret = yield from caller.call_try(data=val) - if ret is None: - assert val != pre_accepted_val or val == pre_accepted_val and (yield from method.done()) - else: - assert val == pre_accepted_val - assert ret["data"] == val - - return process - - def validate_arguments(self, data: int): - return data == self.accepted_val - - def changer(self): - for _ in range(50): - yield Tick("sync_neg") - self.accepted_val = 1 - - @def_method_mock(tb_getter=lambda self: self.m.method, validate_arguments=validate_arguments) - def method_mock(self, data: int): - return {"data": data} - - def test_validate_arguments(self): - 
random.seed(42) - self.m = ValidateArgumentsTestCircuit() - self.accepted_val = 0 - with self.run_simulation(self.m) as sim: - sim.add_process(self.changer) - sim.add_process(self.control_caller(self.m.caller1, self.m.method)) - sim.add_process(self.control_caller(self.m.caller2, self.m.method)) diff --git a/test/transactron/utils/__init__.py b/test/transactron/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/transactron/utils/test_amaranth_ext.py b/test/transactron/utils/test_amaranth_ext.py deleted file mode 100644 index 7943ccb76..000000000 --- a/test/transactron/utils/test_amaranth_ext.py +++ /dev/null @@ -1,135 +0,0 @@ -from transactron.testing import * -import random -from transactron.utils.amaranth_ext import MultiPriorityEncoder, RingMultiPriorityEncoder - - -def get_expected_multi(input_width, output_count, input, *args): - places = [] - for i in range(input_width): - if input % 2: - places.append(i) - input //= 2 - places += [None] * output_count - return places - - -def get_expected_ring(input_width, output_count, input, first, last): - places = [] - input = (input << input_width) + input - if last < first: - last += input_width - for i in range(2 * input_width): - if i >= first and i < last and input % 2: - places.append(i % input_width) - input //= 2 - places += [None] * output_count - return places - - -@pytest.mark.parametrize( - "test_class, verif_f", - [(MultiPriorityEncoder, get_expected_multi), (RingMultiPriorityEncoder, get_expected_ring)], - ids=["MultiPriorityEncoder", "RingMultiPriorityEncoder"], -) -class TestPriorityEncoder(TestCaseWithSimulator): - def process(self, get_expected): - def f(): - for _ in range(self.test_number): - input = random.randrange(2**self.input_width) - first = random.randrange(self.input_width) - last = random.randrange(self.input_width) - yield self.circ.input.eq(input) - try: - yield self.circ.first.eq(first) - yield self.circ.last.eq(last) - except AttributeError: - pass - yield Settle() - expected_output = get_expected(self.input_width, self.output_count, input, first, last) - for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): - if ex is None: - assert (yield valid) == 0 - else: - assert (yield valid) == 1 - assert (yield real) == ex - yield Delay(1e-7) - - return f - - @pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24]) - @pytest.mark.parametrize("output_count", [1, 3, 4]) - def test_random(self, test_class, verif_f, input_width, output_count): - random.seed(input_width + output_count) - self.test_number = 50 - self.input_width = input_width - self.output_count = output_count - self.circ = test_class(self.input_width, self.output_count) - - with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) - - @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create_simple(self, test_class, verif_f, name): - random.seed(14) - self.test_number = 50 - self.input_width = 7 - self.output_count = 1 - - class DUT(Elaboratable): - def __init__(self, input_width, output_count, name): - self.input = Signal(input_width) - self.first = Signal(range(input_width)) - self.last = Signal(range(input_width)) - self.output_count = output_count - self.input_width = input_width - self.name = name - - def elaborate(self, platform): - m = Module() - if test_class == MultiPriorityEncoder: - out, val = test_class.create_simple(m, self.input_width, self.input, name=self.name) - else: - out, val = test_class.create_simple( - m, 
self.input_width, self.input, self.first, self.last, name=self.name - ) - # Save as a list to use common interface in testing - self.outputs = [out] - self.valids = [val] - return m - - self.circ = DUT(self.input_width, self.output_count, name) - - with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) - - @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create(self, test_class, verif_f, name): - random.seed(14) - self.test_number = 50 - self.input_width = 7 - self.output_count = 2 - - class DUT(Elaboratable): - def __init__(self, input_width, output_count, name): - self.input = Signal(input_width) - self.first = Signal(range(input_width)) - self.last = Signal(range(input_width)) - self.output_count = output_count - self.input_width = input_width - self.name = name - - def elaborate(self, platform): - m = Module() - if test_class == MultiPriorityEncoder: - out = test_class.create(m, self.input_width, self.input, self.output_count, name=self.name) - else: - out = test_class.create( - m, self.input_width, self.input, self.first, self.last, self.output_count, name=self.name - ) - self.outputs, self.valids = list(zip(*out)) - return m - - self.circ = DUT(self.input_width, self.output_count, name) - - with self.run_simulation(self.circ) as sim: - sim.add_process(self.process(verif_f)) diff --git a/test/transactron/utils/test_onehotswitch.py b/test/transactron/utils/test_onehotswitch.py deleted file mode 100644 index 9d7dc843f..000000000 --- a/test/transactron/utils/test_onehotswitch.py +++ /dev/null @@ -1,62 +0,0 @@ -from amaranth import * -from amaranth.sim import * - -from transactron.utils import OneHotSwitch - -from transactron.testing import TestCaseWithSimulator - -from parameterized import parameterized - - -class OneHotSwitchCircuit(Elaboratable): - def __init__(self, width: int, test_zero: bool): - self.input = Signal(1 << width) - self.output = Signal(width) - self.zero = Signal() - self.test_zero = test_zero - - def elaborate(self, platform): - m = Module() - - with OneHotSwitch(m, self.input) as OneHotCase: - for i in range(len(self.input)): - with OneHotCase(1 << i): - m.d.comb += self.output.eq(i) - - if self.test_zero: - with OneHotCase(): - m.d.comb += self.zero.eq(1) - - return m - - -class TestAssign(TestCaseWithSimulator): - @parameterized.expand([(False,), (True,)]) - def test_onehotswitch(self, test_zero): - circuit = OneHotSwitchCircuit(4, test_zero) - - def switch_test_proc(): - for i in range(len(circuit.input)): - yield circuit.input.eq(1 << i) - yield Settle() - assert (yield circuit.output) == i - - with self.run_simulation(circuit) as sim: - sim.add_process(switch_test_proc) - - def test_onehotswitch_zero(self): - circuit = OneHotSwitchCircuit(4, True) - - def switch_test_proc_zero(): - for i in range(len(circuit.input)): - yield circuit.input.eq(1 << i) - yield Settle() - assert (yield circuit.output) == i - assert not (yield circuit.zero) - - yield circuit.input.eq(0) - yield Settle() - assert (yield circuit.zero) - - with self.run_simulation(circuit) as sim: - sim.add_process(switch_test_proc_zero) diff --git a/test/transactron/utils/test_utils.py b/test/transactron/utils/test_utils.py deleted file mode 100644 index 63c176169..000000000 --- a/test/transactron/utils/test_utils.py +++ /dev/null @@ -1,196 +0,0 @@ -import unittest -import random - -from amaranth import * -from transactron.testing import * -from transactron.utils import ( - align_to_power_of_two, - align_down_to_power_of_two, - popcount, - 
count_leading_zeros, - count_trailing_zeros, -) -from parameterized import parameterized_class - - -class TestAlignToPowerOfTwo(unittest.TestCase): - def test_align_to_power_of_two(self): - test_cases = [ - (2, 2, 4), - (2, 1, 2), - (3, 1, 4), - (7, 3, 8), - (8, 3, 8), - (14, 3, 16), - (17, 3, 24), - (33, 3, 40), - (33, 1, 34), - (33, 0, 33), - (33, 4, 48), - (33, 5, 64), - (33, 6, 64), - ] - - for num, power, expected in test_cases: - out = align_to_power_of_two(num, power) - assert expected == out - - def test_align_down_to_power_of_two(self): - test_cases = [ - (3, 1, 2), - (3, 0, 3), - (3, 3, 0), - (8, 3, 8), - (8, 2, 8), - (33, 5, 32), - (29, 5, 0), - (29, 1, 28), - (29, 3, 24), - ] - - for num, power, expected in test_cases: - out = align_down_to_power_of_two(num, power) - assert expected == out - - -class PopcountTestCircuit(Elaboratable): - def __init__(self, size: int): - self.sig_in = Signal(size) - self.sig_out = Signal(size) - - def elaborate(self, platform): - m = Module() - - m.d.comb += self.sig_out.eq(popcount(self.sig_in)) - - return m - - -@parameterized_class( - ("name", "size"), - [("size" + str(s), s) for s in [2, 3, 4, 5, 6, 8, 10, 16, 21, 32, 33, 64, 1025]], -) -class TestPopcount(TestCaseWithSimulator): - size: int - - def setup_method(self): - random.seed(14) - self.test_number = 40 - self.m = PopcountTestCircuit(self.size) - - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_popcount = yield self.m.sig_out - assert out_popcount == n.bit_count(), f"{n:x}" - - def process(self): - for i in range(self.test_number): - n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) - - def test_popcount(self): - with self.run_simulation(self.m) as sim: - sim.add_process(self.process) - - -class CLZTestCircuit(Elaboratable): - def __init__(self, xlen_log: int): - self.sig_in = Signal(1 << xlen_log) - self.sig_out = Signal(xlen_log + 1) - self.xlen_log = xlen_log - - def elaborate(self, platform): - m = Module() - - m.d.comb += self.sig_out.eq(count_leading_zeros(self.sig_in)) - # dummy signal - s = Signal() - m.d.sync += s.eq(1) - - return m - - -@parameterized_class( - ("name", "size"), - [("size" + str(s), s) for s in range(1, 7)], -) -class TestCountLeadingZeros(TestCaseWithSimulator): - size: int - - def setup_method(self): - random.seed(14) - self.test_number = 40 - self.m = CLZTestCircuit(self.size) - - def check(self, n): - yield self.m.sig_in.eq(n) - yield Settle() - out_clz = yield self.m.sig_out - assert out_clz == (2**self.size) - n.bit_length(), f"{n:x}" - - def process(self): - for i in range(self.test_number): - n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) - - def test_count_leading_zeros(self): - with self.run_simulation(self.m) as sim: - sim.add_process(self.process) - - -class CTZTestCircuit(Elaboratable): - def __init__(self, xlen_log: int): - self.sig_in = Signal(1 << xlen_log) - self.sig_out = Signal(xlen_log + 1) - self.xlen_log = xlen_log - - def elaborate(self, platform): - m = Module() - - m.d.comb += self.sig_out.eq(count_trailing_zeros(self.sig_in)) - # dummy signal - s = Signal() - m.d.sync += s.eq(1) - - return m - - -@parameterized_class( - ("name", "size"), - [("size" + str(s), s) for s in range(1, 7)], -) -class TestCountTrailingZeros(TestCaseWithSimulator): - size: int - - def setup_method(self): - random.seed(14) - self.test_number = 40 - self.m = CTZTestCircuit(self.size) - - def check(self, n): - yield 
self.m.sig_in.eq(n) - yield Settle() - out_ctz = yield self.m.sig_out - - expected = 0 - if n == 0: - expected = 2**self.size - else: - while (n & 1) == 0: - expected += 1 - n >>= 1 - - assert out_ctz == expected, f"{n:x}" - - def process(self): - for i in range(self.test_number): - n = random.randrange(2**self.size) - yield from self.check(n) - yield from self.check(2**self.size - 1) - - def test_count_trailing_zeros(self): - with self.run_simulation(self.m) as sim: - sim.add_process(self.process) diff --git a/transactron/__init__.py b/transactron/__init__.py deleted file mode 100644 index c162fe991..000000000 --- a/transactron/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .core import * # noqa: F401 diff --git a/transactron/core/__init__.py b/transactron/core/__init__.py deleted file mode 100644 index 6ead593f8..000000000 --- a/transactron/core/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .tmodule import * # noqa: F401 -from .transaction_base import * # noqa: F401 -from .method import * # noqa: F401 -from .transaction import * # noqa: F401 -from .manager import * # noqa: F401 -from .sugar import * # noqa: F401 diff --git a/transactron/core/keys.py b/transactron/core/keys.py deleted file mode 100644 index 9444dce34..000000000 --- a/transactron/core/keys.py +++ /dev/null @@ -1,13 +0,0 @@ -from transactron.utils import * -from typing import TYPE_CHECKING -from dataclasses import dataclass - -if TYPE_CHECKING: - from .manager import TransactionManager # noqa: F401 because of https://github.com/PyCQA/pyflakes/issues/571 - -__all__ = ["TransactionManagerKey"] - - -@dataclass(frozen=True) -class TransactionManagerKey(SimpleKey["TransactionManager"]): - pass diff --git a/transactron/core/manager.py b/transactron/core/manager.py deleted file mode 100644 index cfbc6b17d..000000000 --- a/transactron/core/manager.py +++ /dev/null @@ -1,537 +0,0 @@ -from collections import defaultdict, deque -from typing import Callable, Iterable, Sequence, TypeAlias, Tuple -from os import environ -from graphlib import TopologicalSorter -from amaranth import * -from amaranth.lib.wiring import Component, connect, flipped -from itertools import chain, filterfalse, product - -from amaranth_types import AbstractComponent - -from transactron.utils import * -from transactron.utils.transactron_helpers import _graph_ccs -from transactron.graph import OwnershipGraph, Direction - -from .transaction_base import TransactionBase, TransactionOrMethod, Priority, Relation -from .method import Method -from .transaction import Transaction, TransactionManagerKey -from .tmodule import TModule -from .schedulers import eager_deterministic_cc_scheduler - -__all__ = ["TransactionManager", "TransactionModule", "TransactionComponent"] - -TransactionGraph: TypeAlias = Graph["Transaction"] -TransactionGraphCC: TypeAlias = GraphCC["Transaction"] -PriorityOrder: TypeAlias = dict["Transaction", int] -TransactionScheduler: TypeAlias = Callable[["MethodMap", TransactionGraph, TransactionGraphCC, PriorityOrder], Module] - - -class MethodMap: - def __init__(self, transactions: Iterable["Transaction"]): - self.methods_by_transaction = dict[Transaction, list[Method]]() - self.transactions_by_method = defaultdict[Method, list[Transaction]](list) - self.readiness_by_call = dict[tuple[Transaction, Method], ValueLike]() - self.ancestors_by_call = dict[tuple[Transaction, Method], tuple[Method, ...]]() - self.method_parents = defaultdict[Method, list[TransactionBase]](list) - - def rec(transaction: Transaction, source: TransactionBase, ancestors: 
tuple[Method, ...]): - for method, (arg_rec, _) in source.method_uses.items(): - if not method.defined: - raise RuntimeError(f"Trying to use method '{method.name}' which is not defined yet") - if method in self.methods_by_transaction[transaction]: - raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction") - self.methods_by_transaction[transaction].append(method) - self.transactions_by_method[method].append(transaction) - self.readiness_by_call[(transaction, method)] = method._validate_arguments(arg_rec) - self.ancestors_by_call[(transaction, method)] = new_ancestors = (method, *ancestors) - rec(transaction, method, new_ancestors) - - for transaction in transactions: - self.methods_by_transaction[transaction] = [] - rec(transaction, transaction, ()) - - for transaction_or_method in self.methods_and_transactions: - for method in transaction_or_method.method_uses.keys(): - self.method_parents[method].append(transaction_or_method) - - def transactions_for(self, elem: TransactionOrMethod) -> Collection["Transaction"]: - if isinstance(elem, Transaction): - return [elem] - else: - return self.transactions_by_method[elem] - - @property - def methods(self) -> Collection["Method"]: - return self.transactions_by_method.keys() - - @property - def transactions(self) -> Collection["Transaction"]: - return self.methods_by_transaction.keys() - - @property - def methods_and_transactions(self) -> Iterable[TransactionOrMethod]: - return chain(self.methods, self.transactions) - - -class TransactionManager(Elaboratable): - """Transaction manager - - This module is responsible for granting `Transaction`\\s and running - `Method`\\s. It takes care that two conflicting `Transaction`\\s - are never granted in the same clock cycle. - """ - - def __init__(self, cc_scheduler: TransactionScheduler = eager_deterministic_cc_scheduler): - self.transactions: list[Transaction] = [] - self.cc_scheduler = cc_scheduler - - def add_transaction(self, transaction: "Transaction"): - self.transactions.append(transaction) - - @staticmethod - def _conflict_graph(method_map: MethodMap) -> Tuple[TransactionGraph, PriorityOrder]: - """_conflict_graph - - This function generates the graph of transaction conflicts. Conflicts - between transactions can be explicit or implicit. Two transactions - conflict explicitly if a conflict was added between the transactions - or the methods used by them via `add_conflict`. Two transactions - conflict implicitly if they are both using the same method. - - The created graph is undirected. Transactions are nodes in that graph - and a conflict between two transactions is marked as an edge. In this - representation, connected components are sets of transactions which can - potentially conflict, so there is a need to arbitrate between them. - On the other hand, when two transactions are in different connected - components, they can be scheduled independently, because they - will have no conflicts. - - This function also computes a linear ordering of transactions - which is consistent with conflict priorities of methods and - transactions. When priority constraints cannot be satisfied, - an exception is thrown. - - Returns - ------- - cgr : TransactionGraph - Graph of conflicts between transactions, where vertices are transactions and edges are conflicts. - porder : PriorityOrder - Linear ordering of transactions which is consistent with priority constraints.
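As an illustration of an implicit conflict (a sketch; the `fifo` instance and its `write` method are hypothetical), two transactions calling the same method land in one connected component, so at most one of them is granted per cycle:

    with Transaction().body(m):
        fifo.write(m, data=1)
    with Transaction().body(m):
        fifo.write(m, data=2)  # implicit conflict: both transactions use fifo.write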
- """ - - def transactions_exclusive(trans1: Transaction, trans2: Transaction): - tms1 = [trans1] + method_map.methods_by_transaction[trans1] - tms2 = [trans2] + method_map.methods_by_transaction[trans2] - - # if first transaction is exclusive with the second transaction, or this is true for - # any called methods, the transactions will never run at the same time - for tm1, tm2 in product(tms1, tms2): - if tm1.ctrl_path.exclusive_with(tm2.ctrl_path): - return True - - return False - - def calls_nonexclusive(trans1: Transaction, trans2: Transaction, method: Method): - ancestors1 = method_map.ancestors_by_call[(trans1, method)] - ancestors2 = method_map.ancestors_by_call[(trans2, method)] - common_ancestors = longest_common_prefix(ancestors1, ancestors2) - return common_ancestors[-1].nonexclusive - - cgr: TransactionGraph = {} # Conflict graph - pgr: TransactionGraph = {} # Priority graph - - def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict: bool): - if conflict: - cgr[begin].add(end) - cgr[end].add(begin) - match priority: - case Priority.LEFT: - pgr[end].add(begin) - case Priority.RIGHT: - pgr[begin].add(end) - - for transaction in method_map.transactions: - cgr[transaction] = set() - pgr[transaction] = set() - - for method in method_map.methods: - for transaction1 in method_map.transactions_for(method): - for transaction2 in method_map.transactions_for(method): - if ( - transaction1 is not transaction2 - and not transactions_exclusive(transaction1, transaction2) - and not calls_nonexclusive(transaction1, transaction2, method) - ): - add_edge(transaction1, transaction2, Priority.UNDEFINED, True) - - relations = [ - Relation(**relation, start=elem) - for elem in method_map.methods_and_transactions - for relation in elem.relations - ] - - for relation in relations: - start = relation["start"] - end = relation["end"] - if not relation["conflict"]: # relation added with schedule_before - if end.def_order < start.def_order and not relation["silence_warning"]: - raise RuntimeError(f"{start.name!r} scheduled before {end.name!r}, but defined afterwards") - - for trans_start in method_map.transactions_for(start): - for trans_end in method_map.transactions_for(end): - conflict = relation["conflict"] and not transactions_exclusive(trans_start, trans_end) - add_edge(trans_start, trans_end, relation["priority"], conflict) - - porder: PriorityOrder = {} - - for k, transaction in enumerate(TopologicalSorter(pgr).static_order()): - porder[transaction] = k - - return cgr, porder - - @staticmethod - def _method_enables(method_map: MethodMap) -> Mapping["Transaction", Mapping["Method", ValueLike]]: - method_enables = defaultdict[Transaction, dict[Method, ValueLike]](dict) - enables: list[ValueLike] = [] - - def rec(transaction: Transaction, source: TransactionOrMethod): - for method, (_, enable) in source.method_uses.items(): - enables.append(enable) - rec(transaction, method) - method_enables[transaction][method] = Cat(*enables).all() - enables.pop() - - for transaction in method_map.transactions: - rec(transaction, transaction) - - return method_enables - - @staticmethod - def _method_calls( - m: Module, method_map: MethodMap - ) -> tuple[Mapping["Method", Sequence[MethodStruct]], Mapping["Method", Sequence[Value]]]: - args = defaultdict[Method, list[MethodStruct]](list) - runs = defaultdict[Method, list[Value]](list) - - for source in method_map.methods_and_transactions: - if isinstance(source, Method): - run_val = Cat(transaction.grant for transaction in 
method_map.transactions_by_method[source]).any() - run = Signal() - m.d.comb += run.eq(run_val) - else: - run = source.grant - for method, (arg, _) in source.method_uses.items(): - args[method].append(arg) - runs[method].append(run) - - return (args, runs) - - def _simultaneous(self): - method_map = MethodMap(self.transactions) - - # remove orderings between simultaneous methods/transactions - # TODO: can it be done after transitivity, possibly catching more cases? - for elem in method_map.methods_and_transactions: - all_sims = frozenset(elem.simultaneous_list) - elem.relations = list( - filterfalse( - lambda relation: not relation["conflict"] - and relation["priority"] != Priority.UNDEFINED - and relation["end"] in all_sims, - elem.relations, - ) - ) - - # step 1: simultaneous and independent sets generation - independents = defaultdict[Transaction, set[Transaction]](set) - - for elem in method_map.methods_and_transactions: - indeps = frozenset[Transaction]().union( - *(frozenset(method_map.transactions_for(ind)) for ind in chain([elem], elem.independent_list)) - ) - for transaction1, transaction2 in product(indeps, indeps): - independents[transaction1].add(transaction2) - - simultaneous = set[frozenset[Transaction]]() - - for elem in method_map.methods_and_transactions: - for sim_elem in elem.simultaneous_list: - for tr1, tr2 in product(method_map.transactions_for(elem), method_map.transactions_for(sim_elem)): - if tr1 in independents[tr2]: - raise RuntimeError( - f"Unsatisfiable simultaneity constraints for '{elem.name}' and '{sim_elem.name}'" - ) - simultaneous.add(frozenset({tr1, tr2})) - - # step 2: transitivity computation - tr_simultaneous = set[frozenset[Transaction]]() - - def conflicting(group: frozenset[Transaction]): - return any(tr1 != tr2 and tr1 in independents[tr2] for tr1 in group for tr2 in group) - - q = deque[frozenset[Transaction]](simultaneous) - - while q: - new_group = q.popleft() - if new_group in tr_simultaneous or conflicting(new_group): - continue - q.extend(new_group | other_group for other_group in simultaneous if new_group & other_group) - tr_simultaneous.add(new_group) - - # step 3: maximal group selection - def maximal(group: frozenset[Transaction]): - return not any(group.issubset(group2) and group != group2 for group2 in tr_simultaneous) - - final_simultaneous = set(filter(maximal, tr_simultaneous)) - - # step 4: convert transactions to methods - joined_transactions = set[Transaction]().union(*final_simultaneous) - - self.transactions = list(filter(lambda t: t not in joined_transactions, self.transactions)) - methods = dict[Transaction, Method]() - - for transaction in joined_transactions: - # TODO: some simpler way? 
- method = Method(name=transaction.name) - method.owner = transaction.owner - method.src_loc = transaction.src_loc - method.ready = transaction.request - method.run = transaction.grant - method.defined = transaction.defined - method.method_calls = transaction.method_calls - method.method_uses = transaction.method_uses - method.relations = transaction.relations - method.def_order = transaction.def_order - method.ctrl_path = transaction.ctrl_path - methods[transaction] = method - - for elem in method_map.methods_and_transactions: - # I guess method/transaction unification is really needed - for relation in elem.relations: - if relation["end"] in methods: - relation["end"] = methods[relation["end"]] - - # step 5: construct merged transactions - m = TModule() - m._MustUse__silence = True # type: ignore - - for group in final_simultaneous: - name = "_".join([t.name for t in group]) - with Transaction(manager=self, name=name).body(m): - for transaction in group: - methods[transaction](m) - - return m - - def elaborate(self, platform): - # In the following, various problems in the transaction set-up are detected. - # The exception triggers an unused Elaboratable warning. - with silence_mustuse(self): - merge_manager = self._simultaneous() - - method_map = MethodMap(self.transactions) - cgr, porder = TransactionManager._conflict_graph(method_map) - - m = Module() - m.submodules.merge_manager = merge_manager - - for elem in method_map.methods_and_transactions: - elem._set_method_uses(m) - - for transaction in self.transactions: - ready = [ - method_map.readiness_by_call[transaction, method] - for method in method_map.methods_by_transaction[transaction] - ] - m.d.comb += transaction.runnable.eq(Cat(ready).all()) - - ccs = _graph_ccs(cgr) - m.submodules._transactron_schedulers = ModuleConnector( - *[self.cc_scheduler(method_map, cgr, cc, porder) for cc in ccs] - ) - - method_enables = self._method_enables(method_map) - - for method, transactions in method_map.transactions_by_method.items(): - granted = Cat(transaction.grant & method_enables[transaction][method] for transaction in transactions) - m.d.comb += method.run.eq(granted.any()) - - (method_args, method_runs) = self._method_calls(m, method_map) - - for method in method_map.methods: - if len(method_args[method]) == 1: - m.d.comb += method.data_in.eq(method_args[method][0]) - else: - if method.single_caller: - raise RuntimeError(f"Single-caller method '{method.name}' called more than once") - - runs = Cat(method_runs[method]) - m.d.comb += assign(method.data_in, method.combiner(m, method_args[method], runs), fields=AssignType.ALL) - - if "TRANSACTRON_VERBOSE" in environ: - self.print_info(cgr, porder, ccs, method_map) - - return m - - def print_info( - self, cgr: TransactionGraph, porder: PriorityOrder, ccs: list[GraphCC["Transaction"]], method_map: MethodMap - ): - print("Transactron statistics") - print(f"\tMethods: {len(method_map.methods)}") - print(f"\tTransactions: {len(method_map.transactions)}") - print(f"\tIndependent subgraphs: {len(ccs)}") - print(f"\tAvg callers per method: {average_dict_of_lists(method_map.transactions_by_method):.2f}") - print(f"\tAvg conflicts per transaction: {average_dict_of_lists(cgr):.2f}") - print("") - print("Transaction subgraphs") - for cc in ccs: - ccl = list(cc) - ccl.sort(key=lambda t: porder[t]) - for t in ccl: - print(f"\t{t.name}") - print("") - print("Calling transactions per method") - for m, ts in method_map.transactions_by_method.items(): - print(f"\t{m.owned_name}: 
{m.src_loc[0]}:{m.src_loc[1]}") - for t in ts: - print(f"\t\t{t.name}: {t.src_loc[0]}:{t.src_loc[1]}") - print("") - print("Called methods per transaction") - for t, ms in method_map.methods_by_transaction.items(): - print(f"\t{t.name}: {t.src_loc[0]}:{t.src_loc[1]}") - for m in ms: - print(f"\t\t{m.owned_name}: {m.src_loc[0]}:{m.src_loc[1]}") - print("") - - def visual_graph(self, fragment): - graph = OwnershipGraph(fragment) - method_map = MethodMap(self.transactions) - for method, transactions in method_map.transactions_by_method.items(): - if len(method.data_in.as_value()) > len(method.data_out.as_value()): - direction = Direction.IN - elif method.data_in.shape().size < method.data_out.shape().size: - direction = Direction.OUT - else: - direction = Direction.INOUT - graph.insert_node(method) - for transaction in transactions: - graph.insert_node(transaction) - graph.insert_edge(transaction, method, direction) - - return graph - - def debug_signals(self) -> SignalBundle: - method_map = MethodMap(self.transactions) - cgr, _ = TransactionManager._conflict_graph(method_map) - - def transaction_debug(t: Transaction): - return ( - [t.request, t.grant] - + [m.ready for m in method_map.methods_by_transaction[t]] - + [t2.grant for t2 in cgr[t]] - ) - - def method_debug(m: Method): - return [m.ready, m.run, {t.name: transaction_debug(t) for t in method_map.transactions_by_method[m]}] - - return { - "transactions": {t.name: transaction_debug(t) for t in method_map.transactions}, - "methods": {m.owned_name: method_debug(m) for m in method_map.methods}, - } - - -class TransactionModule(Elaboratable): - """ - `TransactionModule` is used as a wrapper on `Elaboratable` classes, - adding support for transactions. It creates a - `TransactionManager` which will handle transaction scheduling - and can be used in the definition of `Method`\\s and `Transaction`\\s. - The `TransactionManager` is stored in a `DependencyManager`. - """ - - def __init__( - self, - elaboratable: HasElaborate, - dependency_manager: Optional[DependencyManager] = None, - transaction_manager: Optional[TransactionManager] = None, - ): - """ - Parameters - ---------- - elaboratable: HasElaborate - The `Elaboratable` which should be wrapped to add support for - transactions and methods. - dependency_manager: DependencyManager, optional - The `DependencyManager` to use inside the transaction module. - If omitted, a new one is created. - transaction_manager: TransactionManager, optional - The `TransactionManager` to use inside the transaction module. - If omitted, a new one is created. - """ - if transaction_manager is None: - transaction_manager = TransactionManager() - if dependency_manager is None: - dependency_manager = DependencyManager() - self.manager = dependency_manager - self.manager.add_dependency(TransactionManagerKey(), transaction_manager) - self.elaboratable = elaboratable - - def context(self) -> DependencyContext: - return DependencyContext(self.manager) - - def elaborate(self, platform): - with silence_mustuse(self.manager.get_dependency(TransactionManagerKey())): - with self.context(): - elaboratable = Fragment.get(self.elaboratable, platform) - - m = Module() - - m.submodules.main_module = elaboratable - m.submodules.transactionManager = self.transaction_manager = self.manager.get_dependency( - TransactionManagerKey() - ) - - return m - - -class TransactionComponent(TransactionModule, Component): - """Top-level component for Transactron projects.
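A usage sketch (the `MyCore` component is hypothetical): wrap the design's top-level `Component` and hand the wrapper to simulation or HDL generation:

    core = MyCore()  # an amaranth.lib.wiring.Component built from methods and transactions
    top = TransactionComponent(core)  # the ports of `core` are forwarded to `top`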
- - The `TransactionComponent` is a wrapper on `Component` classes, - which adds Transactron support for the wrapped class. The use - case is to wrap a top-level module of the project, and to pass the - wrapped module to simulation, HDL generation or synthesis. - The ports of the wrapped component are forwarded to the wrapper. - - It extends the functionality of `TransactionModule`. - """ - - def __init__( - self, - component: AbstractComponent, - dependency_manager: Optional[DependencyManager] = None, - transaction_manager: Optional[TransactionManager] = None, - ): - """ - Parameters - ---------- - component: Component - The `Component` which should be wrapped to add support for - transactions and methods. - dependency_manager: DependencyManager, optional - The `DependencyManager` to use inside the transaction component. - If omitted, a new one is created. - transaction_manager: TransactionManager, optional - The `TransactionManager` to use inside the transaction component. - If omitted, a new one is created. - """ - TransactionModule.__init__(self, component, dependency_manager, transaction_manager) - Component.__init__(self, component.signature) - - def elaborate(self, platform): - m = super().elaborate(platform) - - assert isinstance(self.elaboratable, Component) # for typing - connect(m, flipped(self), self.elaboratable) - - return m diff --git a/transactron/core/method.py b/transactron/core/method.py deleted file mode 100644 index b5d573fcd..000000000 --- a/transactron/core/method.py +++ /dev/null @@ -1,315 +0,0 @@ -from collections.abc import Sequence -from transactron.utils import * -from amaranth import * -from amaranth import tracer -from typing import Optional, Callable, Iterator, TYPE_CHECKING -from .transaction_base import * -from .sugar import def_method -from contextlib import contextmanager -from transactron.utils.assign import AssignArg - -if TYPE_CHECKING: - from .tmodule import TModule - -__all__ = ["Method"] - - -class Method(TransactionBase): - """Transactional method. - - A `Method` serves to interface a module with external `Transaction`\\s - or `Method`\\s. It can be called at most once in a given clock cycle. - When a given `Method` is required by multiple `Transaction`\\s - (either directly, or indirectly via another `Method`) simultaneously, - at most one of them is granted by the `TransactionManager`, and the rest - of them must wait. (Non-exclusive methods are an exception to this - behavior.) Calling a `Method` always takes a single clock cycle. - - Data is combinationally transferred to and from `Method`\\s - using Amaranth structures (`View` with a `StructLayout`). The transfer - can take place in both directions at the same time: from the called - `Method` to the caller (`data_out`) and from the caller to the called - `Method` (`data_in`). - - A module which defines a `Method` should use `body` or `def_method` - to describe the method's effect on the module state. - - Attributes - ---------- - name: str - Name of this `Method`. - ready: Signal, in - Signals that the method is ready to run in the current cycle. - Typically defined by calling `body`. - run: Signal, out - Signals that the method is called in the current cycle by some - `Transaction`. Defined by the `TransactionManager`. - data_in: MethodStruct, out - Contains the data passed to the `Method` by the caller - (a `Transaction` or another `Method`). - data_out: MethodStruct, in - Contains the data passed from the `Method` to the caller - (a `Transaction` or another `Method`).
Typically defined by - calling `body`. - """ - - def __init__( - self, - *, - name: Optional[str] = None, - i: MethodLayout = (), - o: MethodLayout = (), - nonexclusive: bool = False, - combiner: Optional[Callable[[Module, Sequence[MethodStruct], Value], AssignArg]] = None, - single_caller: bool = False, - src_loc: int | SrcLoc = 0, - ): - """ - Parameters - ---------- - name: str or None - Name hint for this `Method`. If `None` (default), the name is - inferred from the variable name this `Method` is assigned to. - i: method layout - The format of `data_in`. - o: method layout - The format of `data_out`. - nonexclusive: bool - If true, the method is non-exclusive: it can be called by multiple - transactions in the same clock cycle. If such a situation happens, - the method is still executed only once, and each of the callers - receives its output. Nonexclusive methods cannot have inputs unless - a combiner is provided. - combiner: (Module, Sequence[MethodStruct], Value) -> AssignArg - If `nonexclusive` is true, the combiner function combines the - arguments from multiple calls to this method into a single - argument, which is passed to the method body. The third argument - is a bit vector, whose n-th bit is 1 if the n-th call is active - in a given cycle. - single_caller: bool - If true, this method is intended to be called from a single - transaction. An error will be thrown if called from multiple - transactions. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - super().__init__(src_loc=get_src_loc(src_loc)) - - def default_combiner(m: Module, args: Sequence[MethodStruct], runs: Value) -> AssignArg: - ret = Signal(from_method_layout(i)) - for k in OneHotSwitchDynamic(m, runs): - m.d.comb += ret.eq(args[k]) - return ret - - self.owner, owner_name = get_caller_class_name(default="$method") - self.name = name or tracer.get_var_name(depth=2, default=owner_name) - self.ready = Signal(name=self.owned_name + "_ready") - self.run = Signal(name=self.owned_name + "_run") - self.data_in: MethodStruct = Signal(from_method_layout(i)) - self.data_out: MethodStruct = Signal(from_method_layout(o)) - self.nonexclusive = nonexclusive - self.combiner: Callable[[Module, Sequence[MethodStruct], Value], AssignArg] = combiner or default_combiner - self.single_caller = single_caller - self.validate_arguments: Optional[Callable[..., ValueLike]] = None - if nonexclusive: - assert len(self.data_in.as_value()) == 0 or combiner is not None - - @property - def layout_in(self): - return self.data_in.shape() - - @property - def layout_out(self): - return self.data_out.shape() - - @staticmethod - def like(other: "Method", *, name: Optional[str] = None, src_loc: int | SrcLoc = 0) -> "Method": - """Constructs a new `Method` based on another. - - The returned `Method` has the same input/output data layouts as the - `other` `Method`. - - Parameters - ---------- - other : Method - The `Method` which serves as a blueprint for the new `Method`. - name : str, optional - Name of the new `Method`. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - - Returns - ------- - Method - The freshly constructed `Method`. - """ - return Method(name=name, i=other.layout_in, o=other.layout_out, src_loc=get_src_loc(src_loc)) - - def proxy(self, m: "TModule", method: "Method"): - """Define as a proxy for another method.
- - The calls to this method will be forwarded to `method`. - - Parameters - ---------- - m : TModule - Module in which operations on signals should be executed; - `proxy` uses the combinational domain only. - method : Method - The method for which this method is a proxy. - """ - - @def_method(m, self, ready=method.ready) - def _(arg): - return method(m, arg) - - @contextmanager - def body( - self, - m: "TModule", - *, - ready: ValueLike = C(1), - out: ValueLike = C(0, 0), - validate_arguments: Optional[Callable[..., ValueLike]] = None, - ) -> Iterator[MethodStruct]: - """Define method body - - The `body` context manager can be used to define the actions - performed by a `Method` when it's run. Each assignment added to - a domain under `body` is guarded by the `run` signal. - Combinational assignments which do not need to be guarded by `run` - can be added to `m.d.av_comb` or `m.d.top_comb` instead of `m.d.comb`. - `Method` calls can be performed under `body`. - - Parameters - ---------- - m : TModule - Module in which operations on signals should be executed; - `body` uses the combinational domain only. - ready : Signal, in - Signal to indicate if the method is ready to be run. By - default it is `Const(1)`, so the method is always ready. - Assigned combinationally to the `ready` attribute. - out : Value, in - Data generated by the `Method`, which will be passed to - the caller (a `Transaction` or another `Method`). Assigned - combinationally to the `data_out` attribute. - validate_arguments: Optional[Callable[..., ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. - It instantiates a combinational circuit for each - method caller. By default, there is no function, so all arguments - are accepted. - - Returns - ------- - data_in : MethodStruct, out - Data passed from the caller (a `Transaction` or another - `Method`) to this `Method`. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - m = Module() - my_sum_method = Method(i = Layout([("arg1",8),("arg2",8)])) - sum = Signal(16) - with my_sum_method.body(m, out = sum) as data_in: - m.d.comb += sum.eq(data_in.arg1 + data_in.arg2) - """ - if self.defined: - raise RuntimeError(f"Method '{self.name}' already defined") - self.def_order = next(TransactionBase.def_counter) - self.validate_arguments = validate_arguments - - m.d.av_comb += self.ready.eq(ready) - m.d.top_comb += self.data_out.eq(out) - with self.context(m): - with m.AvoidedIf(self.run): - yield self.data_in - - def _validate_arguments(self, arg_rec: MethodStruct) -> ValueLike: - if self.validate_arguments is not None: - return self.ready & method_def_helper(self, self.validate_arguments, arg_rec) - return self.ready - - def __call__( - self, m: "TModule", arg: Optional[AssignArg] = None, enable: ValueLike = C(1), /, **kwargs: AssignArg - ) -> MethodStruct: - """Call a method. - - Methods can only be called from transaction and method bodies. - Calling a `Method` marks, for the purpose of transaction scheduling, - the dependency between the calling context and the called `Method`. - It also connects the method's inputs to the parameters and the - method's outputs to the return value. - - Parameters - ---------- - m : TModule - Module in which operations on signals should be executed. - arg : Value or dict of Values - Call argument. Can be passed as a `View` of the method's - input layout or as a dictionary. Alternative syntax uses - keyword arguments.
- enable : Value - Configures the call as enabled in the current clock cycle. - Disabled calls still lock the called method in transaction - scheduling. Calls are by default enabled. - **kwargs : Value or dict of Values - Allows passing method arguments using keyword argument - syntax. Equivalent to passing a dict as the argument. - - Returns - ------- - data_out : MethodStruct - The result of the method call. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - m = Module() - with Transaction().body(m): - ret = my_sum_method(m, arg1=2, arg2=3) - - Alternative syntax: - - .. highlight:: python - .. code-block:: python - - with Transaction().body(m): - ret = my_sum_method(m, {"arg1": 2, "arg2": 3}) - """ - arg_rec = Signal.like(self.data_in) - - if arg is not None and kwargs: - raise ValueError(f"Method '{self.name}' call with both keyword arguments and legacy record argument") - - if arg is None: - arg = kwargs - - enable_sig = Signal(name=self.owned_name + "_enable") - m.d.av_comb += enable_sig.eq(enable) - m.d.top_comb += assign(arg_rec, arg, fields=AssignType.ALL) - - caller = TransactionBase.get() - if not all(ctrl_path.exclusive_with(m.ctrl_path) for ctrl_path, _, _ in caller.method_calls[self]): - raise RuntimeError(f"Method '{self.name}' can't be called twice from the same caller '{caller.name}'") - caller.method_calls[self].append((m.ctrl_path, arg_rec, enable_sig)) - - if self not in caller.method_uses: - arg_rec_use = Signal(self.layout_in) - arg_rec_enable_sig = Signal() - caller.method_uses[self] = (arg_rec_use, arg_rec_enable_sig) - - return self.data_out - - def __repr__(self) -> str: - return "(method {})".format(self.name) - - def debug_signals(self) -> SignalBundle: - return [self.ready, self.run, self.data_in, self.data_out] diff --git a/transactron/core/schedulers.py b/transactron/core/schedulers.py deleted file mode 100644 index 856d4450b..000000000 --- a/transactron/core/schedulers.py +++ /dev/null @@ -1,77 +0,0 @@ -from amaranth import * -from typing import TYPE_CHECKING -from transactron.utils import * - -if TYPE_CHECKING: - from .manager import MethodMap, TransactionGraph, TransactionGraphCC, PriorityOrder - -__all__ = ["eager_deterministic_cc_scheduler", "trivial_roundrobin_cc_scheduler"] - - -def eager_deterministic_cc_scheduler( - method_map: "MethodMap", gr: "TransactionGraph", cc: "TransactionGraphCC", porder: "PriorityOrder" -) -> Module: - """eager_deterministic_cc_scheduler - - This function generates an eager scheduler for the transaction - subsystem. It isn't fair: transactions are prioritized by their - index in `cc`. The transaction with the lowest - index has the highest priority. - - If two different transactions have no conflicts, - they are started concurrently. - - Parameters - ---------- - method_map : MethodMap - Description of the methods and transactions for which this - scheduler arbitrates grant signals. - gr : TransactionGraph - Graph of conflicts between transactions, where vertices are transactions and edges are conflicts. - cc : Set[Transaction] - Connected component of the graph `gr` for which the scheduler - should be generated. - porder : PriorityOrder - Linear ordering of transactions which is consistent with priority constraints.
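A sketch of how a scheduler is selected (the wrapped `design` elaboratable is hypothetical): the scheduler is given to `TransactionManager`, which can then be passed to `TransactionModule`:

    tm = TransactionManager(cc_scheduler=trivial_roundrobin_cc_scheduler)
    top = TransactionModule(design, transaction_manager=tm)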
- """ - m = Module() - ccl = list(cc) - ccl.sort(key=lambda transaction: porder[transaction]) - for k, transaction in enumerate(ccl): - conflicts = [ccl[j].grant for j in range(k) if ccl[j] in gr[transaction]] - noconflict = ~Cat(conflicts).any() - m.d.comb += transaction.grant.eq(transaction.request & transaction.runnable & noconflict) - return m - - -def trivial_roundrobin_cc_scheduler( - method_map: "MethodMap", gr: "TransactionGraph", cc: "TransactionGraphCC", porder: "PriorityOrder" -) -> Module: - """trivial_roundrobin_cc_scheduler - - This function generates a simple round-robin scheduler for the transaction - subsystem. In a one cycle there will be at most one transaction granted - (in a given connected component of the conflict graph), even if there is - another ready, non-conflicting, transaction. It is mainly for testing - purposes. - - Parameters - ---------- - manager : TransactionManager - TransactionManager which uses this instance of scheduler for - arbitrating which agent should get grant signal. - gr : TransactionGraph - Graph of conflicts between transactions, where vertices are transactions and edges are conflicts. - cc : Set[Transaction] - Connected components of the graph `gr` for which scheduler - should be generated. - porder : PriorityOrder - Linear ordering of transactions which is consistent with priority constraints. - """ - m = Module() - sched = Scheduler(len(cc)) - m.submodules.scheduler = sched - for k, transaction in enumerate(cc): - m.d.comb += sched.requests[k].eq(transaction.request & transaction.runnable) - m.d.comb += transaction.grant.eq(sched.grant[k] & sched.valid) - return m diff --git a/transactron/core/sugar.py b/transactron/core/sugar.py deleted file mode 100644 index 640cddbb5..000000000 --- a/transactron/core/sugar.py +++ /dev/null @@ -1,180 +0,0 @@ -from collections.abc import Sequence, Callable -from amaranth import * -from typing import TYPE_CHECKING, Optional, Concatenate, ParamSpec -from transactron.utils import * -from transactron.utils.assign import AssignArg -from functools import partial - -if TYPE_CHECKING: - from .tmodule import TModule - from .method import Method - -__all__ = ["def_method", "def_methods"] - - -P = ParamSpec("P") - - -def def_method( - m: "TModule", - method: "Method", - ready: ValueLike = C(1), - validate_arguments: Optional[Callable[..., ValueLike]] = None, -): - """Define a method. - - This decorator allows to define transactional methods in an - elegant way using Python's `def` syntax. Internally, `def_method` - uses `Method.body`. - - The decorated function should take keyword arguments corresponding to the - fields of the method's input layout. The `**kwargs` syntax is supported. - Alternatively, it can take one argument named `arg`, which will be a - structure with input signals. - - The returned value can be either a structure with the method's output layout - or a dictionary of outputs. - - Parameters - ---------- - m: TModule - Module in which operations on signals should be executed. - method: Method - The method whose body is going to be defined. - ready: Signal - Signal to indicate if the method is ready to be run. By - default it is `Const(1)`, so the method is always ready. - Assigned combinationally to the `ready` attribute. - validate_arguments: Optional[Callable[..., ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. - It instantiates a combinational circuit for each - method caller. 
By default, there is no function, so all arguments - are accepted. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - m = Module() - my_sum_method = Method(i=[("arg1",8),("arg2",8)], o=[("res",8)]) - @def_method(m, my_sum_method) - def _(arg1, arg2): - return arg1 + arg2 - - Alternative syntax (keyword args in dictionary): - - .. highlight:: python - .. code-block:: python - - @def_method(m, my_sum_method) - def _(**args): - return args["arg1"] + args["arg2"] - - Alternative syntax (arg structure): - - .. highlight:: python - .. code-block:: python - - @def_method(m, my_sum_method) - def _(arg): - return {"res": arg.arg1 + arg.arg2} - """ - - def decorator(func: Callable[..., Optional[AssignArg]]): - out = Signal(method.layout_out) - ret_out = None - - with method.body(m, ready=ready, out=out, validate_arguments=validate_arguments) as arg: - ret_out = method_def_helper(method, func, arg) - - if ret_out is not None: - m.d.top_comb += assign(out, ret_out, fields=AssignType.ALL) - - return decorator - - -def def_methods( - m: "TModule", - methods: Sequence["Method"], - ready: Callable[[int], ValueLike] = lambda _: C(1), - validate_arguments: Optional[Callable[..., ValueLike]] = None, -): - """Decorator for defining similar methods - - This decorator is a wrapper over `def_method`, which allows you to easily - define multiple similar methods in a loop. - - The function over which this decorator is applied, should always expect - at least one argument, as the index of the method will be passed as the - first argument to the function. - - This is a syntax sugar equivalent to: - - .. highlight:: python - .. code-block:: python - - for i in range(len(my_methods)): - @def_method(m, my_methods[i]) - def _(arg): - ... - - Parameters - ---------- - m: TModule - Module in which operations on signals should be executed. - methods: Sequence[Method] - The methods whose body is going to be defined. - ready: Callable[[int], Value] - A `Callable` that takes the index in the form of an `int` of the currently defined method - and produces a `Value` describing whether the method is ready to be run. - When omitted, each defined method is always ready. Assigned combinationally to the `ready` attribute. - validate_arguments: Optional[Callable[Concatenate[int, ...], ValueLike]] - Function that takes input arguments used to call the method - and checks whether the method can be called with those arguments. - It instantiates a combinational circuit for each - method caller. By default, there is no function, so all arguments - are accepted. - - Examples - -------- - Define three methods with the same body: - - .. highlight:: python - .. code-block:: python - - m = TModule() - my_sum_methods = [Method(i=[("arg1",8),("arg2",8)], o=[("res",8)]) for _ in range(3)] - @def_methods(m, my_sum_methods) - def _(_, arg1, arg2): - return arg1 + arg2 - - Define three methods with different bodies parametrized with the index of the method: - - .. highlight:: python - .. code-block:: python - - m = TModule() - my_sum_methods = [Method(i=[("arg1",8),("arg2",8)], o=[("res",8)]) for _ in range(3)] - @def_methods(m, my_sum_methods) - def _(index : int, arg1, arg2): - return arg1 + arg2 + index - - Define three methods with different ready signals: - - .. highlight:: python - .. 
code-block:: python
-
-        @def_methods(m, my_filter_read_methods, ready=lambda i: fifo.head == i)
-        def _(_):
-            return fifo.read(m)
-    """
-
-    def decorator(func: Callable[Concatenate[int, P], Optional[RecordDict]]):
-        for i in range(len(methods)):
-            partial_f = partial(func, i)
-            partial_vargs = partial(validate_arguments, i) if validate_arguments is not None else None
-            def_method(m, methods[i], ready(i), partial_vargs)(partial_f)
-
-    return decorator
diff --git a/transactron/core/tmodule.py b/transactron/core/tmodule.py
deleted file mode 100644
index d4276dce7..000000000
--- a/transactron/core/tmodule.py
+++ /dev/null
@@ -1,286 +0,0 @@
-from enum import Enum, auto
-from dataclasses import dataclass, replace
-from amaranth import *
-from typing import Optional, Self, NoReturn
-from contextlib import contextmanager
-from amaranth.hdl._dsl import FSM
-from transactron.utils import *
-
-__all__ = ["TModule"]
-
-
-class _AvoidingModuleBuilderDomain:
-    """
-    A wrapper over an Amaranth domain to abstract away the internal Amaranth
-    implementation. It is needed to allow for a correctness check in
-    `__setattr__`, which uses `isinstance`.
-    """
-
-    def __init__(self, amaranth_module_domain):
-        self._domain = amaranth_module_domain
-
-    def __iadd__(self, assigns: StatementLike) -> Self:
-        self._domain.__iadd__(assigns)
-        return self
-
-
-class _AvoidingModuleBuilderDomains:
-    _m: "TModule"
-
-    def __init__(self, m: "TModule"):
-        object.__setattr__(self, "_m", m)
-
-    def __getattr__(self, name: str) -> _AvoidingModuleBuilderDomain:
-        if name == "av_comb":
-            return _AvoidingModuleBuilderDomain(self._m.avoiding_module.d["comb"])
-        elif name == "top_comb":
-            return _AvoidingModuleBuilderDomain(self._m.top_module.d["comb"])
-        else:
-            return _AvoidingModuleBuilderDomain(self._m.main_module.d[name])
-
-    def __getitem__(self, name: str) -> _AvoidingModuleBuilderDomain:
-        return self.__getattr__(name)
-
-    def __setattr__(self, name: str, value):
-        if not isinstance(value, _AvoidingModuleBuilderDomain):
-            raise AttributeError(f"Cannot assign 'd.{name}' attribute; did you mean 'd.{name} +='?")
-
-    def __setitem__(self, name: str, value):
-        return self.__setattr__(name, value)
-
-
-class EnterType(Enum):
-    """Characterizes stack behavior of Amaranth's context managers for control structures."""
-
-    #: Used for `m.If`, `m.Switch` and `m.FSM`.
-    PUSH = auto()
-    #: Used for `m.Elif` and `m.Else`.
-    ADD = auto()
-    #: Used for `m.Case`, `m.Default` and `m.State`.
-    ENTRY = auto()
-
-
-@dataclass(frozen=True)
-class PathEdge:
-    """Describes an edge in Amaranth's control tree.
-
-    Attributes
-    ----------
-    alt : int
-        Which alternative (e.g. case of `m.If` or `m.Switch`) is described.
-    par : int
-        Which parallel control structure (e.g. `m.If` at the same level) is described.
-    """
-
-    alt: int = 0
-    par: int = 0
-
-
-@dataclass
-class CtrlPath:
-    """Describes a path in Amaranth's control tree.
-
-    Attributes
-    ----------
-    module : int
-        Unique number of the module the path refers to.
-    path : list[PathEdge]
-        Path in the control tree, starting from the root.
-    """
-
-    module: int
-    path: list[PathEdge]
-
-    def exclusive_with(self, other: "CtrlPath"):
-        """Decides if this path is mutually exclusive with some other path.
-
-        Paths are mutually exclusive if they refer to the same module and
-        diverge on different alternatives of the same control structure.
-
-        Parameters
-        ----------
-        other : CtrlPath
-            The other path this path is compared to.
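-
-        For illustration, a runnable sketch of the semantics (paths built by
-        hand; in real designs they come from `CtrlPathBuilder`):
-
-        .. code-block:: python
-
-            e = PathEdge(alt=0, par=0)
-            # Same module, diverging on different alternatives of one structure.
-            p1 = CtrlPath(module=1, path=[e, PathEdge(alt=0, par=0)])
-            p2 = CtrlPath(module=1, path=[e, PathEdge(alt=1, par=0)])
-            assert p1.exclusive_with(p2)
-            # A path is never exclusive with its own prefix.
-            assert not p1.exclusive_with(CtrlPath(module=1, path=[e]))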
- """ - common_prefix = [] - for a, b in zip(self.path, other.path): - if a == b: - common_prefix.append(a) - elif a.par != b.par: - return False - else: - break - - return ( - self.module == other.module - and len(common_prefix) != len(self.path) - and len(common_prefix) != len(other.path) - ) - - -class CtrlPathBuilder: - """Constructs control paths. - - Used internally by `TModule`.""" - - def __init__(self, module: int): - """ - Parameters - ---------- - module: int - Unique module identifier. - """ - self.module = module - self.ctrl_path: list[PathEdge] = [] - self.previous: Optional[PathEdge] = None - - @contextmanager - def enter(self, enter_type=EnterType.PUSH): - et = EnterType - - match enter_type: - case et.ADD: - assert self.previous is not None - self.ctrl_path.append(replace(self.previous, alt=self.previous.alt + 1)) - case et.ENTRY: - self.ctrl_path[-1] = replace(self.ctrl_path[-1], alt=self.ctrl_path[-1].alt + 1) - case et.PUSH: - if self.previous is not None: - self.ctrl_path.append(PathEdge(par=self.previous.par + 1)) - else: - self.ctrl_path.append(PathEdge()) - self.previous = None - try: - yield - finally: - if enter_type in [et.PUSH, et.ADD]: - self.previous = self.ctrl_path.pop() - - def build_ctrl_path(self): - """Returns the current control path.""" - return CtrlPath(self.module, self.ctrl_path[:]) - - -class TModule(ModuleLike, Elaboratable): - """Extended Amaranth module for use with transactions. - - It includes three different combinational domains: - - * `comb` domain, works like the `comb` domain in plain Amaranth modules. - Statements in `comb` are guarded by every condition, including - `AvoidedIf`. This means they are guarded by transaction and method - bodies: they don't execute if the given transaction/method is not run. - * `av_comb` is guarded by all conditions except `AvoidedIf`. This means - they are not guarded by transaction and method bodies. This allows to - reduce the amount of useless multplexers due to transaction use, while - still allowing the use of conditions in transaction/method bodies. - * `top_comb` is unguarded: statements added to this domain always - execute. It can be used to reduce combinational path length due to - multplexers while keeping related combinational and synchronous - statements together. 
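-
-    For illustration, a sketch of how the three domains differ inside a
-    method body (`my_method` and the signals are assumptions):
-
-    .. code-block:: python
-
-        m = TModule()
-
-        @def_method(m, my_method)
-        def _(arg):
-            m.d.comb += guarded.eq(arg.data)      # runs only when the method runs
-            m.d.av_comb += prepared.eq(arg.data)  # not guarded by the method body
-            m.d.top_comb += always.eq(arg.data)   # always executes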
- """ - - __next_uid = 0 - - def __init__(self): - self.main_module = Module() - self.avoiding_module = Module() - self.top_module = Module() - self.d = _AvoidingModuleBuilderDomains(self) - self.submodules = self.main_module.submodules - self.domains = self.main_module.domains - self.fsm: Optional[FSM] = None - self.uid = TModule.__next_uid - self.path_builder = CtrlPathBuilder(self.uid) - TModule.__next_uid += 1 - - @contextmanager - def AvoidedIf(self, cond: ValueLike): # noqa: N802 - with self.main_module.If(cond): - with self.path_builder.enter(EnterType.PUSH): - yield - - @contextmanager - def If(self, cond: ValueLike): # noqa: N802 - with self.main_module.If(cond): - with self.avoiding_module.If(cond): - with self.path_builder.enter(EnterType.PUSH): - yield - - @contextmanager - def Elif(self, cond): # noqa: N802 - with self.main_module.Elif(cond): - with self.avoiding_module.Elif(cond): - with self.path_builder.enter(EnterType.ADD): - yield - - @contextmanager - def Else(self): # noqa: N802 - with self.main_module.Else(): - with self.avoiding_module.Else(): - with self.path_builder.enter(EnterType.ADD): - yield - - @contextmanager - def Switch(self, test: ValueLike): # noqa: N802 - with self.main_module.Switch(test): - with self.avoiding_module.Switch(test): - with self.path_builder.enter(EnterType.PUSH): - yield - - @contextmanager - def Case(self, *patterns: SwitchKey): # noqa: N802 - with self.main_module.Case(*patterns): - with self.avoiding_module.Case(*patterns): - with self.path_builder.enter(EnterType.ENTRY): - yield - - @contextmanager - def Default(self): # noqa: N802 - with self.main_module.Default(): - with self.avoiding_module.Default(): - with self.path_builder.enter(EnterType.ENTRY): - yield - - @contextmanager - def FSM(self, init: Optional[str] = None, domain: str = "sync", name: str = "fsm"): # noqa: N802 - old_fsm = self.fsm - with self.main_module.FSM(init, domain, name) as fsm: - self.fsm = fsm - with self.path_builder.enter(EnterType.PUSH): - yield fsm - self.fsm = old_fsm - - @contextmanager - def State(self, name: str): # noqa: N802 - assert self.fsm is not None - with self.main_module.State(name): - with self.avoiding_module.If(self.fsm.ongoing(name)): - with self.path_builder.enter(EnterType.ENTRY): - yield - - @property - def next(self) -> NoReturn: - raise NotImplementedError - - @next.setter - def next(self, name: str): - self.main_module.next = name - - @property - def ctrl_path(self): - return self.path_builder.build_ctrl_path() - - @property - def _MustUse__silence(self): # noqa: N802 - return self.main_module._MustUse__silence - - @_MustUse__silence.setter - def _MustUse__silence(self, value): # noqa: N802 - self.main_module._MustUse__silence = value # type: ignore - self.avoiding_module._MustUse__silence = value # type: ignore - self.top_module._MustUse__silence = value # type: ignore - - def elaborate(self, platform): - self.main_module.submodules._avoiding_module = self.avoiding_module - self.main_module.submodules._top_module = self.top_module - return self.main_module diff --git a/transactron/core/transaction.py b/transactron/core/transaction.py deleted file mode 100644 index c6f4176ab..000000000 --- a/transactron/core/transaction.py +++ /dev/null @@ -1,115 +0,0 @@ -from transactron.utils import * -from amaranth import * -from amaranth import tracer -from typing import Optional, Iterator, TYPE_CHECKING -from .transaction_base import * -from .keys import * -from contextlib import contextmanager - -if TYPE_CHECKING: - from .tmodule import TModule 
- from .manager import TransactionManager - -__all__ = ["Transaction"] - - -class Transaction(TransactionBase): - """Transaction. - - A `Transaction` represents a task which needs to be regularly done. - Execution of a `Transaction` always lasts a single clock cycle. - A `Transaction` signals readiness for execution by setting the - `request` signal. If the conditions for its execution are met, it - can be granted by the `TransactionManager`. - - A `Transaction` can, as part of its execution, call a number of - `Method`\\s. A `Transaction` can be granted only if every `Method` - it runs is ready. - - A `Transaction` cannot execute concurrently with another, conflicting - `Transaction`. Conflicts between `Transaction`\\s are either explicit - or implicit. An explicit conflict is added using the `add_conflict` - method. Implicit conflicts arise between pairs of `Transaction`\\s - which use the same `Method`. - - A module which defines a `Transaction` should use `body` to - describe used methods and the transaction's effect on the module state. - The used methods should be called inside the `body`'s - `with` block. - - Attributes - ---------- - name: str - Name of this `Transaction`. - request: Signal, in - Signals that the transaction wants to run. If omitted, the transaction - is always ready. Defined in the constructor. - runnable: Signal, out - Signals that all used methods are ready. - grant: Signal, out - Signals that the transaction is granted by the `TransactionManager`, - and all used methods are called. - """ - - def __init__( - self, *, name: Optional[str] = None, manager: Optional["TransactionManager"] = None, src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - name: str or None - Name hint for this `Transaction`. If `None` (default) the name is - inferred from the variable name this `Transaction` is assigned to. - If the `Transaction` was not assigned, the name is inferred from - the class name where the `Transaction` was constructed. - manager: TransactionManager - The `TransactionManager` controlling this `Transaction`. - If omitted, the manager is received from `TransactionContext`. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - super().__init__(src_loc=get_src_loc(src_loc)) - self.owner, owner_name = get_caller_class_name(default="$transaction") - self.name = name or tracer.get_var_name(depth=2, default=owner_name) - if manager is None: - manager = DependencyContext.get().get_dependency(TransactionManagerKey()) - manager.add_transaction(self) - self.request = Signal(name=self.owned_name + "_request") - self.runnable = Signal(name=self.owned_name + "_runnable") - self.grant = Signal(name=self.owned_name + "_grant") - - @contextmanager - def body(self, m: "TModule", *, request: ValueLike = C(1)) -> Iterator["Transaction"]: - """Defines the `Transaction` body. - - This context manager allows to conveniently define the actions - performed by a `Transaction` when it's granted. Each assignment - added to a domain under `body` is guarded by the `grant` signal. - Combinational assignments which do not need to be guarded by - `grant` can be added to `m.d.top_comb` or `m.d.av_comb` instead of - `m.d.comb`. `Method` calls can be performed under `body`. - - Parameters - ---------- - m: TModule - The module where the `Transaction` is defined. - request: Signal - Indicates that the `Transaction` wants to be executed. 
By - default it is `Const(1)`, so it wants to be executed in - every clock cycle. - """ - if self.defined: - raise RuntimeError(f"Transaction '{self.name}' already defined") - self.def_order = next(TransactionBase.def_counter) - - m.d.av_comb += self.request.eq(request) - with self.context(m): - with m.AvoidedIf(self.grant): - yield self - - def __repr__(self) -> str: - return "(transaction {})".format(self.name) - - def debug_signals(self) -> SignalBundle: - return [self.request, self.runnable, self.grant] diff --git a/transactron/core/transaction_base.py b/transactron/core/transaction_base.py deleted file mode 100644 index be4fe6f93..000000000 --- a/transactron/core/transaction_base.py +++ /dev/null @@ -1,209 +0,0 @@ -from collections import defaultdict -from contextlib import contextmanager -from enum import Enum, auto -from itertools import count -from typing import ( - ClassVar, - TypeAlias, - TypedDict, - Union, - TypeVar, - Protocol, - Self, - runtime_checkable, - TYPE_CHECKING, - Iterator, -) -from amaranth import * - -from .tmodule import TModule, CtrlPath -from transactron.graph import Owned -from transactron.utils import * - -if TYPE_CHECKING: - from .method import Method - from .transaction import Transaction - -__all__ = ["TransactionBase", "Priority"] - -TransactionOrMethod: TypeAlias = Union["Transaction", "Method"] -TransactionOrMethodBound = TypeVar("TransactionOrMethodBound", "Transaction", "Method") - - -class Priority(Enum): - #: Conflicting transactions/methods don't have a priority order. - UNDEFINED = auto() - #: Left transaction/method is prioritized over the right one. - LEFT = auto() - #: Right transaction/method is prioritized over the left one. - RIGHT = auto() - - -class RelationBase(TypedDict): - end: TransactionOrMethod - priority: Priority - conflict: bool - silence_warning: bool - - -class Relation(RelationBase): - start: TransactionOrMethod - - -@runtime_checkable -class TransactionBase(Owned, Protocol): - stack: ClassVar[list[Union["Transaction", "Method"]]] = [] - def_counter: ClassVar[count] = count() - def_order: int - defined: bool = False - name: str - src_loc: SrcLoc - method_uses: dict["Method", tuple[MethodStruct, Signal]] - method_calls: defaultdict["Method", list[tuple[CtrlPath, MethodStruct, ValueLike]]] - relations: list[RelationBase] - simultaneous_list: list[TransactionOrMethod] - independent_list: list[TransactionOrMethod] - ctrl_path: CtrlPath = CtrlPath(-1, []) - - def __init__(self, *, src_loc: int | SrcLoc): - self.src_loc = get_src_loc(src_loc) - self.method_uses = {} - self.method_calls = defaultdict(list) - self.relations = [] - self.simultaneous_list = [] - self.independent_list = [] - - def add_conflict(self, end: TransactionOrMethod, priority: Priority = Priority.UNDEFINED) -> None: - """Registers a conflict. - - Record that that the given `Transaction` or `Method` cannot execute - simultaneously with this `Method` or `Transaction`. Typical reason - is using a common resource (register write or memory port). - - Parameters - ---------- - end: Transaction or Method - The conflicting `Transaction` or `Method` - priority: Priority, optional - Is one of conflicting `Transaction`\\s or `Method`\\s prioritized? - Defaults to undefined priority relation. - """ - self.relations.append( - RelationBase(end=end, priority=priority, conflict=True, silence_warning=self.owner != end.owner) - ) - - def schedule_before(self, end: TransactionOrMethod) -> None: - """Adds a priority relation. 
-
-        Record that the given `Transaction` or `Method` needs to be
-        scheduled before this `Method` or `Transaction`, without adding
-        a conflict. A typical reason is data forwarding.
-
-        Parameters
-        ----------
-        end: Transaction or Method
-            The other `Transaction` or `Method`
-        """
-        self.relations.append(
-            RelationBase(end=end, priority=Priority.LEFT, conflict=False, silence_warning=self.owner != end.owner)
-        )
-
-    def simultaneous(self, *others: TransactionOrMethod) -> None:
-        """Adds simultaneity relations.
-
-        The given `Transaction`\\s or `Method`\\s will execute simultaneously
-        (in the same clock cycle) with this `Transaction` or `Method`.
-
-        Parameters
-        ----------
-        *others: Transaction or Method
-            The `Transaction`\\s or `Method`\\s to be executed simultaneously.
-        """
-        self.simultaneous_list += others
-
-    def simultaneous_alternatives(self, *others: TransactionOrMethod) -> None:
-        """Adds exclusive simultaneity relations.
-
-        Each of the given `Transaction`\\s or `Method`\\s will execute
-        simultaneously (in the same clock cycle) with this `Transaction` or
-        `Method`. However, each of the given `Transaction`\\s or `Method`\\s
-        will be separately considered for execution.
-
-        Parameters
-        ----------
-        *others: Transaction or Method
-            The `Transaction`\\s or `Method`\\s to be executed simultaneously,
-            but mutually exclusive, with this `Transaction` or `Method`.
-        """
-        self.simultaneous(*others)
-        others[0]._independent(*others[1:])
-
-    def _independent(self, *others: TransactionOrMethod) -> None:
-        """Adds independence relations.
-
-        This `Transaction` or `Method`, together with all the given
-        `Transaction`\\s or `Method`\\s, will never be considered (pairwise)
-        for simultaneous execution.
-
-        Warning: this function is an implementation detail, do not use in
-        user code.
-
-        Parameters
-        ----------
-        *others: Transaction or Method
-            The `Transaction`\\s or `Method`\\s which, together with this
-            `Transaction` or `Method`, need to be independently considered
-            for execution.
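-
-        For illustration, how the public relation API builds on these
-        relations (a sketch; the methods involved are assumptions):
-
-        .. code-block:: python
-
-            clear.add_conflict(read, Priority.LEFT)  # clear wins over read
-            write.schedule_before(read)              # ordering without a conflict
-            write.simultaneous_alternatives(read0, read1)  # uses _independent internally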
- """ - self.independent_list += others - - @contextmanager - def context(self: TransactionOrMethodBound, m: TModule) -> Iterator[TransactionOrMethodBound]: - self.ctrl_path = m.ctrl_path - - parent = TransactionBase.peek() - if parent is not None: - parent.schedule_before(self) - - TransactionBase.stack.append(self) - - try: - yield self - finally: - TransactionBase.stack.pop() - self.defined = True - - def _set_method_uses(self, m: ModuleLike): - for method, calls in self.method_calls.items(): - arg_rec, enable_sig = self.method_uses[method] - if len(calls) == 1: - m.d.comb += arg_rec.eq(calls[0][1]) - m.d.comb += enable_sig.eq(calls[0][2]) - else: - call_ens = Cat([en for _, _, en in calls]) - - for i in OneHotSwitchDynamic(m, call_ens): - m.d.comb += arg_rec.eq(calls[i][1]) - m.d.comb += enable_sig.eq(1) - - @classmethod - def get(cls) -> Self: - ret = cls.peek() - if ret is None: - raise RuntimeError("No current body") - return ret - - @classmethod - def peek(cls) -> Optional[Self]: - if not TransactionBase.stack: - return None - if not isinstance(TransactionBase.stack[-1], cls): - raise RuntimeError(f"Current body not a {cls.__name__}") - return TransactionBase.stack[-1] - - @property - def owned_name(self): - if self.owner is not None and self.owner.__class__.__name__ != self.name: - return f"{self.owner.__class__.__name__}_{self.name}" - else: - return self.name diff --git a/transactron/graph.py b/transactron/graph.py deleted file mode 100644 index 709ba8724..000000000 --- a/transactron/graph.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -Utilities for extracting dependency graphs from Amaranth designs. -""" - -from enum import IntFlag -from collections import defaultdict -from typing import Literal, Optional, Protocol - -from amaranth import Elaboratable, Fragment -from .tracing import TracingFragment - - -class Owned(Protocol): - name: str - owner: Optional[Elaboratable] - - -class Direction(IntFlag): - NONE = 0 - IN = 1 - OUT = 2 - INOUT = 3 - - -class OwnershipGraph: - mermaid_direction = ["---", "-->", "<--", "<-->"] - - def __init__(self, root): - self.class_counters: defaultdict[type, int] = defaultdict(int) - self.owned_counters: defaultdict[tuple[int, str], int] = defaultdict(int) - self.names: dict[int, str] = {} - self.owned_names: dict[int, str] = {} - self.hier: dict[int, str] = {} - self.labels: dict[int, str] = {} - self.graph: dict[int, list[int]] = {} - self.edges: list[tuple[Owned, Owned, Direction]] = [] - self.owned: defaultdict[int, set[Owned]] = defaultdict(set) - self.stray: set[int] = set() - self.remember(root) - - def remember(self, owner: Elaboratable) -> int: - while hasattr(owner, "_tracing_original"): - owner = owner._tracing_original # type: ignore - owner_id = id(owner) - if owner_id not in self.names: - tp = type(owner) - count = self.class_counters[tp] - self.class_counters[tp] = count + 1 - - name = tp.__name__ - if count: - name += str(count) - self.names[owner_id] = name - self.graph[owner_id] = [] - while True: - for field, obj in vars(owner).items(): - if isinstance(obj, Elaboratable) and not field.startswith("_"): - self.remember_field(owner_id, field, obj) - if isinstance(owner, Fragment): - assert isinstance(owner, TracingFragment) - for obj, field, _ in owner.subfragments: - self.remember_field(owner_id, field, obj) - try: - owner = owner._elaborated # type: ignore - except AttributeError: - break - return owner_id - - def remember_field(self, owner_id: int, field: str, obj: Elaboratable): - while hasattr(obj, "_tracing_original"): - obj = 
obj._tracing_original # type: ignore - obj_id = id(obj) - if obj_id == owner_id or obj_id in self.labels: - return - self.labels[obj_id] = f"{field} {obj.__class__.__name__}" - self.graph[owner_id].append(obj_id) - self.remember(obj) - - def insert_node(self, obj: Owned): - assert obj.owner is not None - owner_id = self.remember(obj.owner) - self.owned[owner_id].add(obj) - - def insert_edge(self, fr: Owned, to: Owned, direction: Direction = Direction.OUT): - self.edges.append((fr, to, direction)) - - def get_name(self, obj: Owned) -> str: - assert obj.owner is not None - obj_id = id(obj) - name = self.owned_names.get(obj_id) - if name is not None: - return name - owner_id = self.remember(obj.owner) - count = self.owned_counters[(owner_id, obj.name)] - self.owned_counters[(owner_id, obj.name)] = count + 1 - suffix = str(count) if count else "" - name = self.owned_names[obj_id] = f"{self.names[owner_id]}_{obj.name}{suffix}" - return name - - def get_hier_name(self, obj: Owned) -> str: - """ - Get hierarchical name. - Might raise KeyError if not yet hierarchized. - """ - name = self.get_name(obj) - owner_id = id(obj.owner) - hier = self.hier[owner_id] - return f"{hier}.{name}" - - def prune(self, owner: Optional[int] = None): - """ - Mark all empty subgraphs. - """ - if owner is None: - backup = self.graph.copy() - for owner in self.names: - if owner not in self.labels: - self.prune(owner) - self.graph = backup - return - - subowners = self.graph.pop(owner) - flag = bool(self.owned[owner]) - for subowner in subowners: - if subowner in self.graph: - flag |= self.prune(subowner) - - if not flag: - self.stray.add(owner) - - return flag - - def dump(self, fp, format: Literal["dot", "elk", "mermaid"]): - dumper = getattr(self, "dump_" + format) - dumper(fp) - - def dump_dot(self, fp, owner: Optional[int] = None, indent: str = ""): - if owner is None: - fp.write("digraph G {\n") - for owner in self.names: - if owner not in self.labels: - self.dump_dot(fp, owner, indent) - for fr, to, direction in self.edges: - if direction == Direction.OUT: - fr, to = to, fr - - caller_name = self.get_name(fr) - callee_name = self.get_name(to) - fp.write(f"{caller_name} -> {callee_name}\n") - fp.write("}\n") - return - - subowners = self.graph.pop(owner) - if owner in self.stray: - return - indent += " " - owned = self.owned[owner] - fp.write(f"{indent}subgraph cluster_{self.names[owner]} {{\n") - fp.write(f'{indent} label="{self.labels.get(owner, self.names[owner])}";\n') - for x in owned: - fp.write(f'{indent} {self.get_name(x)} [label="{x.name}"];\n') - for subowner in subowners: - if subowner in self.graph: - self.dump_dot(fp, subowner, indent) - fp.write(f"{indent}}}\n") - - def dump_elk(self, fp, owner: Optional[int] = None, indent: str = ""): - if owner is None: - fp.write(f"{indent}hierarchyHandling: INCLUDE_CHILDREN\n") - fp.write(f"{indent}elk.direction: DOWN\n") - for owner in self.names: - if owner not in self.labels: - self.dump_elk(fp, owner, indent) - return - - hier = self.hier.setdefault(owner, self.names[owner]) - - subowners = self.graph.pop(owner) - if owner in self.stray: - return - owned = self.owned[owner] - fp.write(f"{indent}node {self.names[owner]} {{\n") - fp.write(f"{indent} considerModelOrder.components: INSIDE_PORT_SIDE_GROUPS\n") - fp.write(f'{indent} nodeSize.constraints: "[PORTS, PORT_LABELS, MINIMUM_SIZE]"\n') - fp.write(f'{indent} nodeLabels.placement: "[H_LEFT, V_TOP, OUTSIDE]"\n') - fp.write(f'{indent} portLabels.placement: "[INSIDE]"\n') - fp.write(f"{indent} feedbackEdges: 
true\n") - fp.write(f'{indent} label "{self.labels.get(owner, self.names[owner])}"\n') - for x in owned: - if x.__class__.__name__ == "Method": - fp.write(f'{indent} port {self.get_name(x)} {{ label "{x.name}" }}\n') - else: - fp.write(f"{indent} node {self.get_name(x)} {{\n") - fp.write(f'{indent} nodeSize.constraints: "[NODE_LABELS, MINIMUM_SIZE]"\n') - fp.write(f'{indent} nodeLabels.placement: "[H_CENTER, V_CENTER, INSIDE]"\n') - fp.write(f'{indent} label "{x.name}"\n') - fp.write(f"{indent} }}\n") - for subowner in subowners: - if subowner in self.graph: - self.hier[subowner] = f"{hier}.{self.names[subowner]}" - self.dump_elk(fp, subowner, indent + " ") - - # reverse iteration so that deleting works - for i, (fr, to, direction) in reversed(list(enumerate(self.edges))): - if direction == Direction.OUT: - fr, to = to, fr - - try: - caller_name = self.get_hier_name(fr) - callee_name = self.get_hier_name(to) - except KeyError: - continue - - # only output edges belonging here - if caller_name[: len(hier)] == callee_name[: len(hier)] == hier: - caller_name = caller_name[len(hier) + 1 :] - callee_name = callee_name[len(hier) + 1 :] - del self.edges[i] - fp.write(f"{indent} edge {caller_name} -> {callee_name}\n") - - fp.write(f"{indent}}}\n") - - def dump_mermaid(self, fp, owner: Optional[int] = None, indent: str = ""): - if owner is None: - fp.write("flowchart TB\n") - for owner in self.names: - if owner not in self.labels: - self.dump_mermaid(fp, owner, indent) - for fr, to, direction in self.edges: - if direction == Direction.OUT: - fr, to, direction = to, fr, Direction.IN - - caller_name = self.get_name(fr) - callee_name = self.get_name(to) - fp.write(f"{caller_name} {self.mermaid_direction[direction]} {callee_name}\n") - return - - subowners = self.graph.pop(owner) - if owner in self.stray: - return - indent += " " - owned = self.owned[owner] - fp.write(f'{indent}subgraph {self.names[owner]}["{self.labels.get(owner, self.names[owner])}"]\n') - for x in owned: - fp.write(f'{indent} {self.get_name(x)}["{x.name}"]\n') - for subowner in subowners: - if subowner in self.graph: - self.dump_mermaid(fp, subowner, indent) - fp.write(f"{indent}end\n") diff --git a/transactron/lib/__init__.py b/transactron/lib/__init__.py deleted file mode 100644 index f6dd3ef0a..000000000 --- a/transactron/lib/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .fifo import * # noqa: F401 -from .connectors import * # noqa: F401 -from .buttons import * # noqa: F401 -from .adapters import * # noqa: F401 -from .transformers import * # noqa: F401 -from .reqres import * # noqa: F401 -from .storage import * # noqa: F401 -from .simultaneous import * # noqa: F401 -from .metrics import * # noqa: F401 diff --git a/transactron/lib/adapters.py b/transactron/lib/adapters.py deleted file mode 100644 index 81816b3c4..000000000 --- a/transactron/lib/adapters.py +++ /dev/null @@ -1,149 +0,0 @@ -from abc import abstractmethod -from typing import Optional -from amaranth import * -from amaranth.lib.wiring import Component, In, Out -from amaranth.lib.data import StructLayout, View - -from ..utils import SrcLoc, get_src_loc, MethodStruct -from ..core import * -from ..utils._typing import type_self_kwargs_as, SignalBundle - -__all__ = [ - "AdapterBase", - "AdapterTrans", - "Adapter", -] - - -class AdapterBase(Component): - data_in: MethodStruct - data_out: MethodStruct - en: Signal - done: Signal - - def __init__(self, iface: Method, layout_in: StructLayout, layout_out: StructLayout): - super().__init__({"data_in": In(layout_in), 
"data_out": Out(layout_out), "en": In(1), "done": Out(1)}) - self.iface = iface - - def debug_signals(self) -> SignalBundle: - return [self.en, self.done, self.data_in, self.data_out] - - @abstractmethod - def elaborate(self, platform) -> TModule: - raise NotImplementedError() - - -class AdapterTrans(AdapterBase): - """Adapter transaction. - - Creates a transaction controlled by plain Amaranth signals. Allows to - expose a method to plain Amaranth code, including testbenches. - - Attributes - ---------- - en: Signal, in - Activates the transaction (sets the `request` signal). - done: Signal, out - Signals that the transaction is performed (returns the `grant` - signal). - data_in: View, in - Data passed to the `iface` method. - data_out: View, out - Data returned from the `iface` method. - """ - - def __init__(self, iface: Method, *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - iface: Method - The method to be called by the transaction. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - super().__init__(iface, iface.layout_in, iface.layout_out) - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - # this forces data_in signal to appear in VCD dumps - data_in = Signal.like(self.data_in) - m.d.comb += data_in.eq(self.data_in) - - with Transaction(name=f"AdapterTrans_{self.iface.name}", src_loc=self.src_loc).body(m, request=self.en): - data_out = self.iface(m, data_in) - m.d.top_comb += self.data_out.eq(data_out) - m.d.comb += self.done.eq(1) - - return m - - -class Adapter(AdapterBase): - """Adapter method. - - Creates a method controlled by plain Amaranth signals. One of the - possible uses is to mock a method in a testbench. - - Attributes - ---------- - en: Signal, in - Activates the method (sets the `ready` signal). - done: Signal, out - Signals that the method is called (returns the `run` signal). - data_in: View, in - Data returned from the defined method. - data_out: View, out - Data passed as argument to the defined method. - validators: list of tuples of View, out and Signal, in - Hooks for `validate_arguments`. - """ - - @type_self_kwargs_as(Method.__init__) - def __init__(self, **kwargs): - """ - Parameters - ---------- - **kwargs - Keyword arguments for Method that will be created. - See transactron.core.Method.__init__ for parameters description. 
- """ - - kwargs["src_loc"] = get_src_loc(kwargs.setdefault("src_loc", 0)) - - iface = Method(**kwargs) - super().__init__(iface, iface.layout_out, iface.layout_in) - self.validators: list[tuple[View[StructLayout], Signal]] = [] - self.with_validate_arguments: bool = False - - def set(self, with_validate_arguments: Optional[bool]): - if with_validate_arguments is not None: - self.with_validate_arguments = with_validate_arguments - return self - - def elaborate(self, platform): - m = TModule() - - # this forces data_in signal to appear in VCD dumps - data_in = Signal.like(self.data_in) - m.d.comb += data_in.eq(self.data_in) - - kwargs = {} - - if self.with_validate_arguments: - - def validate_arguments(arg: "View[StructLayout]"): - ret = Signal() - self.validators.append((arg, ret)) - return ret - - kwargs["validate_arguments"] = validate_arguments - - @def_method(m, self.iface, ready=self.en, **kwargs) - def _(arg): - m.d.top_comb += self.data_out.eq(arg) - m.d.comb += self.done.eq(1) - return data_in - - return m diff --git a/transactron/lib/buttons.py b/transactron/lib/buttons.py deleted file mode 100644 index d275cd25d..000000000 --- a/transactron/lib/buttons.py +++ /dev/null @@ -1,113 +0,0 @@ -from amaranth import * - -from transactron.utils.transactron_helpers import from_method_layout -from ..core import * -from ..utils import SrcLoc, get_src_loc, MethodLayout - -__all__ = ["ClickIn", "ClickOut"] - - -class ClickIn(Elaboratable): - """Clicked input. - - Useful for interactive simulations or FPGA button/switch interfaces. - On a rising edge (tested synchronously) of `btn`, the `get` method - is enabled, which returns the data present on `dat` at the time. - Inputs are synchronized. - - Attributes - ---------- - get: Method - The method for retrieving data from the input. Accepts an empty - argument, returns a structure. - btn: Signal, in - The button input. - dat: MethodStruct, in - The data input. - """ - - def __init__(self, layout: MethodLayout, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - layout: method layout - The data format for the input. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - src_loc = get_src_loc(src_loc) - self.get = Method(o=layout, src_loc=src_loc) - self.btn = Signal() - self.dat = Signal(from_method_layout(layout)) - - def elaborate(self, platform): - m = TModule() - - btn1 = Signal() - btn2 = Signal() - dat1 = Signal.like(self.dat) - m.d.sync += btn1.eq(self.btn) - m.d.sync += btn2.eq(btn1) - m.d.sync += dat1.eq(self.dat) - get_ready = Signal() - get_data = Signal.like(self.dat) - - @def_method(m, self.get, ready=get_ready) - def _(): - m.d.sync += get_ready.eq(0) - return get_data - - with m.If(~btn2 & btn1): - m.d.sync += get_ready.eq(1) - m.d.sync += get_data.eq(dat1) - - return m - - -class ClickOut(Elaboratable): - """Clicked output. - - Useful for interactive simulations or FPGA button/LED interfaces. - On a rising edge (tested synchronously) of `btn`, the `put` method - is enabled, which, when called, changes the value of the `dat` signal. - - Attributes - ---------- - put: Method - The method for retrieving data from the input. Accepts a structure, - returns empty result. - btn: Signal, in - The button input. - dat: MethodStruct, out - The data output. - """ - - def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - layout: method layout - The data format for the output. 
- src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - src_loc = get_src_loc(src_loc) - self.put = Method(i=layout, src_loc=src_loc) - self.btn = Signal() - self.dat = Signal(from_method_layout(layout)) - - def elaborate(self, platform): - m = TModule() - - btn1 = Signal() - btn2 = Signal() - m.d.sync += btn1.eq(self.btn) - m.d.sync += btn2.eq(btn1) - - @def_method(m, self.put, ready=~btn2 & btn1) - def _(arg): - m.d.sync += self.dat.eq(arg) - - return m diff --git a/transactron/lib/connectors.py b/transactron/lib/connectors.py deleted file mode 100644 index 723660ff9..000000000 --- a/transactron/lib/connectors.py +++ /dev/null @@ -1,424 +0,0 @@ -from amaranth import * -from amaranth.lib.data import View -import amaranth.lib.fifo - -from transactron.utils.transactron_helpers import from_method_layout -from ..core import * -from ..utils import SrcLoc, get_src_loc, MethodLayout - -__all__ = [ - "FIFO", - "Forwarder", - "Connect", - "ConnectTrans", - "ManyToOneConnectTrans", - "StableSelectingNetwork", - "Pipe", -] - - -class FIFO(Elaboratable): - """FIFO module. - - Provides a transactional interface to Amaranth FIFOs. Exposes two methods: - `read`, and `write`. Both methods are ready only when they can - be executed -- i.e. the queue is respectively not empty / not full. - It is possible to simultaneously read and write in a single clock cycle, - but only if both readiness conditions are fulfilled. - - Attributes - ---------- - read: Method - The read method. Accepts an empty argument, returns a structure. - write: Method - The write method. Accepts a structure, returns empty result. - """ - - def __init__( - self, layout: MethodLayout, depth: int, fifo_type=amaranth.lib.fifo.SyncFIFO, *, src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - layout: method layout - The format of structures stored in the FIFO. - depth: int - Size of the FIFO. - fifoType: Elaboratable - FIFO module conforming to Amaranth library FIFO interface. Defaults - to SyncFIFO. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.layout = from_method_layout(layout) - self.width = self.layout.size - self.depth = depth - self.fifoType = fifo_type - - src_loc = get_src_loc(src_loc) - self.read = Method(o=layout, src_loc=src_loc) - self.write = Method(i=layout, src_loc=src_loc) - - def elaborate(self, platform): - m = TModule() - - m.submodules.fifo = fifo = self.fifoType(width=self.width, depth=self.depth) - - @def_method(m, self.write, ready=fifo.w_rdy) - def _(arg): - m.d.comb += fifo.w_en.eq(1) - m.d.top_comb += fifo.w_data.eq(arg) - - @def_method(m, self.read, ready=fifo.r_rdy) - def _(): - m.d.comb += fifo.r_en.eq(1) - return View(self.layout, fifo.r_data) # remove View after Amaranth upgrade - - return m - - -# Forwarding with overflow buffering - - -class Forwarder(Elaboratable): - """Forwarding with overflow buffering - - Provides a means to connect two transactions with forwarding. Exposes - two methods: `read`, and `write`. When both of these methods are - executed simultaneously, data is forwarded between them. If `write` - is executed, but `read` is not, the value cannot be forwarded, - but is stored into an overflow buffer. No further `write`\\s are - possible until the overflow buffer is cleared by `read`. - - The `write` method is scheduled before `read`. 
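-
-    For illustration, a sketch of typical use (the two transactions, the
-    `data` layout field and `producer_data` are assumptions):
-
-    .. code-block:: python
-
-        with Transaction().body(m):
-            forwarder.write(m, data=producer_data)
-
-        with Transaction().body(m):
-            out = forwarder.read(m)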
- - Attributes - ---------- - read: Method - The read method. Accepts an empty argument, returns a structure. - write: Method - The write method. Accepts a structure, returns empty result. - """ - - def __init__(self, layout: MethodLayout, *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - layout: method layout - The format of structures forwarded. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - src_loc = get_src_loc(src_loc) - self.read = Method(o=layout, src_loc=src_loc) - self.write = Method(i=layout, src_loc=src_loc) - self.clear = Method(src_loc=src_loc) - self.head = Signal.like(self.read.data_out) - - self.clear.add_conflict(self.read, Priority.LEFT) - self.clear.add_conflict(self.write, Priority.LEFT) - - def elaborate(self, platform): - m = TModule() - - reg = Signal.like(self.read.data_out) - reg_valid = Signal() - read_value = Signal.like(self.read.data_out) - m.d.comb += self.head.eq(read_value) - - self.write.schedule_before(self.read) # to avoid combinational loops - - @def_method(m, self.write, ready=~reg_valid) - def _(arg): - m.d.av_comb += read_value.eq(arg) # for forwarding - m.d.sync += reg.eq(arg) - m.d.sync += reg_valid.eq(1) - - with m.If(reg_valid): - m.d.av_comb += read_value.eq(reg) # write method is not ready - - @def_method(m, self.read, ready=reg_valid | self.write.run) - def _(): - m.d.sync += reg_valid.eq(0) - return read_value - - @def_method(m, self.clear) - def _(): - m.d.sync += reg_valid.eq(0) - - return m - - -class Pipe(Elaboratable): - """ - This module implements a `Pipe`. It is a halfway between - `Forwarder` and `2-FIFO`. In the `Pipe` data is always - stored localy, so the critical path of the data is cut, but there is a - combinational path between the control signals of the `read` and - the `write` methods. For comparison: - - in `Forwarder` there is both a data and a control combinational path - - in `2-FIFO` there are no combinational paths - - The `read` method is scheduled before the `write`. - - Attributes - ---------- - read: Method - Reads from the pipe. Accepts an empty argument, returns a structure. - Ready only if the pipe is not empty. - write: Method - Writes to the pipe. Accepts a structure, returns empty result. - Ready only if the pipe is not full. - clean: Method - Cleans the pipe. Has priority over `read` and `write` methods. - """ - - def __init__(self, layout: MethodLayout): - """ - Parameters - ---------- - layout: record layout - The format of records forwarded. - """ - self.read = Method(o=layout) - self.write = Method(i=layout) - self.clean = Method() - self.head = Signal.like(self.read.data_out) - - self.clean.add_conflict(self.read, Priority.LEFT) - self.clean.add_conflict(self.write, Priority.LEFT) - - def elaborate(self, platform): - m = TModule() - - reg = Signal.like(self.read.data_out) - reg_valid = Signal() - - self.read.schedule_before(self.write) # to avoid combinational loops - - @def_method(m, self.read, ready=reg_valid) - def _(): - m.d.sync += reg_valid.eq(0) - return reg - - @def_method(m, self.write, ready=~reg_valid | self.read.run) - def _(arg): - m.d.sync += reg.eq(arg) - m.d.sync += reg_valid.eq(1) - - @def_method(m, self.clean) - def _(): - m.d.sync += reg_valid.eq(0) - - return m - - -class Connect(Elaboratable): - """Forwarding by transaction simultaneity - - Provides a means to connect two transactions with forwarding - by means of the transaction simultaneity mechanism. 
It provides - two methods: `read`, and `write`, which always execute simultaneously. - Typical use case is for moving data from `write` to `read`, but - data flow in the reverse direction is also possible. - - Attributes - ---------- - read: Method - The read method. Accepts a (possibly empty) structure, returns - a structure. - write: Method - The write method. Accepts a structure, returns a (possibly empty) - structure. - """ - - def __init__(self, layout: MethodLayout = (), rev_layout: MethodLayout = (), *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - layout: method layout - The format of structures forwarded. - rev_layout: method layout - The format of structures forwarded in the reverse direction. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - src_loc = get_src_loc(src_loc) - self.read = Method(o=layout, i=rev_layout, src_loc=src_loc) - self.write = Method(i=layout, o=rev_layout, src_loc=src_loc) - - def elaborate(self, platform): - m = TModule() - - read_value = Signal.like(self.read.data_out) - rev_read_value = Signal.like(self.write.data_out) - - self.write.simultaneous(self.read) - - @def_method(m, self.write) - def _(arg): - m.d.av_comb += read_value.eq(arg) - return rev_read_value - - @def_method(m, self.read) - def _(arg): - m.d.av_comb += rev_read_value.eq(arg) - return read_value - - return m - - -class ConnectTrans(Elaboratable): - """Simple connecting transaction. - - Takes two methods and creates a transaction which calls both of them. - Result of the first method is connected to the argument of the second, - and vice versa. Allows easily connecting methods with compatible - layouts. - """ - - def __init__(self, method1: Method, method2: Method, *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - method1: Method - First method. - method2: Method - Second method. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.method1 = method1 - self.method2 = method2 - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - with Transaction(src_loc=self.src_loc).body(m): - data1 = Signal.like(self.method1.data_out) - data2 = Signal.like(self.method2.data_out) - - m.d.top_comb += data1.eq(self.method1(m, data2)) - m.d.top_comb += data2.eq(self.method2(m, data1)) - - return m - - -class ManyToOneConnectTrans(Elaboratable): - """Many-to-one method connection. - - Connects each of a set of methods to another method using separate - transactions. Equivalent to a set of `ConnectTrans`. - """ - - def __init__(self, *, get_results: list[Method], put_result: Method, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - get_results: list[Method] - Methods to be connected to the `put_result` method. - put_result: Method - Common method for each of the connections created. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
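-
-        For illustration, a minimal sketch (the connected methods are
-        assumptions):
-
-        .. code-block:: python
-
-            m.submodules.collect = ManyToOneConnectTrans(
-                get_results=[fu0.get_result, fu1.get_result],
-                put_result=result_fifo.write,
-            )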
- """ - self.get_results = get_results - self.m_put_result = put_result - - self.count = len(self.get_results) - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - for i in range(self.count): - m.submodules[f"ManyToOneConnectTrans_input_{i}"] = ConnectTrans( - self.m_put_result, self.get_results[i], src_loc=self.src_loc - ) - - return m - - -class StableSelectingNetwork(Elaboratable): - """A network that groups inputs with a valid bit set. - - The circuit takes `n` inputs with a valid signal each and - on the output returns a grouped and consecutive sequence of the provided - input signals. The order of valid inputs is preserved. - - For example for input (0 is an invalid input): - 0, a, 0, d, 0, 0, e - - The circuit will return: - a, d, e, 0, 0, 0, 0 - - The circuit uses a divide and conquer algorithm. - The recursive call takes two bit vectors and each of them - is already properly sorted, for example: - v1 = [a, b, 0, 0]; v2 = [c, d, e, 0] - - Now by shifting left v2 and merging it with v1, we get the result: - v = [a, b, c, d, e, 0, 0, 0] - - Thus, the network has depth log_2(n). - - """ - - def __init__(self, n: int, layout: MethodLayout): - self.n = n - self.layout = from_method_layout(layout) - - self.inputs = [Signal(self.layout) for _ in range(n)] - self.valids = [Signal() for _ in range(n)] - - self.outputs = [Signal(self.layout) for _ in range(n)] - self.output_cnt = Signal(range(n + 1)) - - def elaborate(self, platform): - m = TModule() - - current_level = [] - for i in range(self.n): - current_level.append((Array([self.inputs[i]]), self.valids[i])) - - # Create the network using the bottom-up approach. - while len(current_level) >= 2: - next_level = [] - while len(current_level) >= 2: - a, cnt_a = current_level.pop(0) - b, cnt_b = current_level.pop(0) - - total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1) - m.d.comb += total_cnt.eq(cnt_a + cnt_b) - - total_len = len(a) + len(b) - merged = Array(Signal(self.layout) for _ in range(total_len)) - - for i in range(len(a)): - m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i])) - for i in range(len(b)): - m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a])) - - next_level.append((merged, total_cnt)) - - # If we had an odd number of elements on the current level, - # move the item left to the next level. - if len(current_level) == 1: - next_level.append(current_level.pop(0)) - - current_level = next_level - - last_level, total_cnt = current_level.pop(0) - - for i in range(self.n): - m.d.comb += self.outputs[i].eq(last_level[i]) - - m.d.comb += self.output_cnt.eq(total_cnt) - - return m diff --git a/transactron/lib/dependencies.py b/transactron/lib/dependencies.py deleted file mode 100644 index c7b099b76..000000000 --- a/transactron/lib/dependencies.py +++ /dev/null @@ -1,34 +0,0 @@ -from collections.abc import Callable - -from .. import Method -from .transformers import Unifier -from ..utils.dependencies import * - - -__all__ = ["DependencyManager", "DependencyKey", "SimpleKey", "ListKey", "UnifierKey"] - - -class UnifierKey(DependencyKey["Method", tuple["Method", dict[str, "Unifier"]]]): - """Base class for method unifier dependency keys. - - Method unifier dependency keys are used to collect methods to be called by - some part of the core. As multiple modules may wish to be called, a method - unifier is used to present a single method interface to the caller, which - allows to customize the calling behavior. 
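-
-    For illustration, a sketch of a key declaration (here `Collector` stands
-    for whichever `Unifier` subclass fits the use case):
-
-    .. code-block:: python
-
-        @dataclass(frozen=True)
-        class GetResultKey(UnifierKey, unifier=Collector):
-            pass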
- """ - - unifier: Callable[[list["Method"]], "Unifier"] - - def __init_subclass__(cls, unifier: Callable[[list["Method"]], "Unifier"], **kwargs) -> None: - cls.unifier = unifier - return super().__init_subclass__(**kwargs) - - def combine(self, data: list["Method"]) -> tuple["Method", dict[str, "Unifier"]]: - if len(data) == 1: - return data[0], {} - else: - unifiers: dict[str, Unifier] = {} - unifier_inst = self.unifier(data) - unifiers[self.__class__.__name__ + "_unifier"] = unifier_inst - method = unifier_inst.method - return method, unifiers diff --git a/transactron/lib/fifo.py b/transactron/lib/fifo.py deleted file mode 100644 index f9d43c30f..000000000 --- a/transactron/lib/fifo.py +++ /dev/null @@ -1,165 +0,0 @@ -from amaranth import * -import amaranth.lib.memory as memory -from transactron import Method, def_method, Priority, TModule -from transactron.utils._typing import ValueLike, MethodLayout, SrcLoc, MethodStruct -from transactron.utils.amaranth_ext import mod_incr -from transactron.utils.transactron_helpers import from_method_layout, get_src_loc - - -class BasicFifo(Elaboratable): - """Transactional FIFO queue - - Attributes - ---------- - read: Method - Reads from the FIFO. Accepts an empty argument, returns a structure. - Ready only if the FIFO is not empty. - peek: Method - Returns the element at the front (but not delete). Ready only if the FIFO - is not empty. The method is nonexclusive. - write: Method - Writes to the FIFO. Accepts a structure, returns empty result. - Ready only if the FIFO is not full. - clear: Method - Clears the FIFO entries. Has priority over `read` and `write` methods. - Note that, clearing the FIFO doesn't reinitialize it to values passed in `init` parameter. - - """ - - def __init__(self, layout: MethodLayout, depth: int, *, src_loc: int | SrcLoc = 0) -> None: - """ - Parameters - ---------- - layout: method layout - Layout of data stored in the FIFO. - depth: int - Size of the FIFO. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
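-
-        For illustration, a minimal sketch of typical use (the layout, the
-        transactions and `producer_data` are assumptions):
-
-        .. code-block:: python
-
-            m.submodules.fifo = fifo = BasicFifo([("data", 8)], depth=4)
-
-            with Transaction().body(m):
-                fifo.write(m, data=producer_data)
-
-            with Transaction().body(m):
-                out = fifo.read(m)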
- """ - self.layout = layout - self.width = from_method_layout(self.layout).size - self.depth = depth - - src_loc = get_src_loc(src_loc) - self.read = Method(o=self.layout, src_loc=src_loc) - self.peek = Method(o=self.layout, nonexclusive=True, src_loc=src_loc) - self.write = Method(i=self.layout, src_loc=src_loc) - self.clear = Method(src_loc=src_loc) - self.head = Signal(from_method_layout(layout)) - - self.buff = memory.Memory(shape=self.width, depth=self.depth, init=[]) - - self.write_ready = Signal() - self.read_ready = Signal() - - self.read_idx = Signal((self.depth - 1).bit_length()) - self.write_idx = Signal((self.depth - 1).bit_length()) - # current fifo depth - self.level = Signal((self.depth).bit_length()) - - # for interface compatibility with MultiportFifo - self.read_methods = [self.read] - self.write_methods = [self.write] - - def elaborate(self, platform): - m = TModule() - - next_read_idx = Signal.like(self.read_idx) - m.d.comb += next_read_idx.eq(mod_incr(self.read_idx, self.depth)) - - m.submodules.buff = self.buff - self.buff_wrport = self.buff.write_port() - self.buff_rdport = self.buff.read_port(domain="sync", transparent_for=[self.buff_wrport]) - - m.d.comb += self.read_ready.eq(self.level != 0) - m.d.comb += self.write_ready.eq(self.level != self.depth) - - with m.If(self.read.run & ~self.write.run): - m.d.sync += self.level.eq(self.level - 1) - with m.If(self.write.run & ~self.read.run): - m.d.sync += self.level.eq(self.level + 1) - with m.If(self.clear.run): - m.d.sync += self.level.eq(0) - - m.d.comb += self.buff_rdport.addr.eq(Mux(self.read.run, next_read_idx, self.read_idx)) - m.d.comb += self.head.eq(self.buff_rdport.data) - - @def_method(m, self.write, ready=self.write_ready) - def _(arg: MethodStruct) -> None: - m.d.top_comb += self.buff_wrport.addr.eq(self.write_idx) - m.d.top_comb += self.buff_wrport.data.eq(arg) - m.d.comb += self.buff_wrport.en.eq(1) - - m.d.sync += self.write_idx.eq(mod_incr(self.write_idx, self.depth)) - - @def_method(m, self.read, self.read_ready) - def _() -> ValueLike: - m.d.sync += self.read_idx.eq(next_read_idx) - return self.head - - @def_method(m, self.peek, self.read_ready) - def _() -> ValueLike: - return self.head - - @def_method(m, self.clear) - def _() -> None: - m.d.sync += self.read_idx.eq(0) - m.d.sync += self.write_idx.eq(0) - - return m - - -class Semaphore(Elaboratable): - """Semaphore""" - - def __init__(self, max_count: int) -> None: - """ - Parameters - ---------- - size: int - Size of the semaphore. 
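-
-        For illustration, a minimal sketch of typical use (the two
-        transactions are assumptions):
-
-        .. code-block:: python
-
-            m.submodules.sem = sem = Semaphore(max_count=4)
-
-            with Transaction().body(m):
-                sem.acquire(m)  # ready only while count < max_count
-
-            with Transaction().body(m):
-                sem.release(m)  # ready only while count > 0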
- - """ - self.max_count = max_count - - self.acquire = Method() - self.release = Method() - self.clear = Method() - - self.acquire_ready = Signal() - self.release_ready = Signal() - - self.count = Signal(self.max_count.bit_length()) - self.count_next = Signal(self.max_count.bit_length()) - - self.clear.add_conflict(self.acquire, Priority.LEFT) - self.clear.add_conflict(self.release, Priority.LEFT) - - def elaborate(self, platform) -> TModule: - m = TModule() - - m.d.comb += self.release_ready.eq(self.count > 0) - m.d.comb += self.acquire_ready.eq(self.count < self.max_count) - - with m.If(self.clear.run): - m.d.comb += self.count_next.eq(0) - with m.Else(): - m.d.comb += self.count_next.eq(self.count + self.acquire.run - self.release.run) - - m.d.sync += self.count.eq(self.count_next) - - @def_method(m, self.acquire, ready=self.acquire_ready) - def _() -> None: - pass - - @def_method(m, self.release, ready=self.release_ready) - def _() -> None: - pass - - @def_method(m, self.clear) - def _() -> None: - pass - - return m diff --git a/transactron/lib/logging.py b/transactron/lib/logging.py deleted file mode 100644 index 7eb06deb1..000000000 --- a/transactron/lib/logging.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -import re -import operator -import logging -from functools import reduce -from dataclasses import dataclass, field -from dataclasses_json import dataclass_json -from typing import TypeAlias - -from amaranth import * -from amaranth.tracer import get_src_loc - -from transactron.utils import SrcLoc -from transactron.utils._typing import ModuleLike, ValueLike -from transactron.utils.dependencies import DependencyContext, ListKey - -LogLevel: TypeAlias = int - - -@dataclass_json -@dataclass -class LogRecordInfo: - """Simulator-backend-agnostic information about a log record that can - be serialized and used outside the Amaranth context. - - Attributes - ---------- - logger_name: str - - level: LogLevel - The severity level of the log. - format_str: str - The template of the message. Should follow PEP 3101 standard. - location: SrcLoc - Source location of the log. - """ - - logger_name: str - level: LogLevel - format_str: str - location: SrcLoc - - def format(self, *args) -> str: - """Format the log message with a set of concrete arguments.""" - - return self.format_str.format(*args) - - -@dataclass -class LogRecord(LogRecordInfo): - """A LogRecord instance represents an event being logged. - - Attributes - ---------- - trigger: Signal - Amaranth signal triggering the log. - fields: Signal - Amaranth signals that will be used to format the message. - """ - - trigger: Signal - fields: list[Signal] = field(default_factory=list) - - -@dataclass(frozen=True) -class LogKey(ListKey[LogRecord]): - pass - - -class HardwareLogger: - """A class for creating log messages in the hardware. - - Intuitively, the hardware logger works similarly to a normal software - logger. You can log a message anywhere in the circuit, but due to the - parallel nature of the hardware you must specify a special trigger signal - which will indicate if a message shall be reported in that cycle. - - Hardware logs are evaluated and printed during simulation, so both - the trigger and the format fields are Amaranth values, i.e. - signals or arbitrary Amaranth expressions. - - Instances of the HardwareLogger class represent a logger for a single - submodule of the circuit. Exactly how a "submodule" is defined is up - to the developer. Submodule are identified by a unique string and - the names can be nested. 
Names are organized into a namespace hierarchy - where levels are separated by periods, much like the Python package - namespace. So in the instance, submodules names might be "frontend" - for the upper level, and "frontend.icache" and "frontend.bpu" for - the sub-levels. There is no arbitrary limit to the depth of nesting. - - Attributes - ---------- - name: str - Name of this logger. - """ - - def __init__(self, name: str): - """ - Parameters - ---------- - name: str - Name of this logger. Hierarchy levels are separated by periods, - e.g. "backend.fu.jumpbranch". - """ - self.name = name - - def log(self, m: ModuleLike, level: LogLevel, trigger: ValueLike, format: str, *args, src_loc_at: int = 0): - """Registers a hardware log record with the given severity. - - Parameters - ---------- - m: ModuleLike - The module for which the log record is added. - trigger: ValueLike - If the value of this Amaranth expression is true, the log will reported. - format: str - The format of the message as defined in PEP 3101. - *args - Amaranth values that will be read during simulation and used to format - the message. - src_loc_at: int, optional - How many stack frames below to look for the source location, used to - identify the failing assertion. - """ - - def local_src_loc(src_loc: SrcLoc): - return (os.path.relpath(src_loc[0]), src_loc[1]) - - src_loc = local_src_loc(get_src_loc(src_loc_at + 1)) - - trigger_signal = Signal() - m.d.comb += trigger_signal.eq(trigger) - - record = LogRecord( - logger_name=self.name, level=level, format_str=format, location=src_loc, trigger=trigger_signal - ) - - for arg in args: - sig = Signal.like(arg) - m.d.top_comb += sig.eq(arg) - record.fields.append(sig) - - dependencies = DependencyContext.get() - dependencies.add_dependency(LogKey(), record) - - def debug(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): - """Log a message with severity 'DEBUG'. - - See `HardwareLogger.log` function for more details. - """ - self.log(m, logging.DEBUG, trigger, format, *args, **kwargs) - - def info(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): - """Log a message with severity 'INFO'. - - See `HardwareLogger.log` function for more details. - """ - self.log(m, logging.INFO, trigger, format, *args, **kwargs) - - def warning(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): - """Log a message with severity 'WARNING'. - - See `HardwareLogger.log` function for more details. - """ - self.log(m, logging.WARNING, trigger, format, *args, **kwargs) - - def error(self, m: ModuleLike, trigger: ValueLike, format: str, *args, **kwargs): - """Log a message with severity 'ERROR'. - - This severity level has special semantics. If a log with this serverity - level is triggered, the simulation will be terminated. - - See `HardwareLogger.log` function for more details. - """ - self.log(m, logging.ERROR, trigger, format, *args, **kwargs) - - def assertion(self, m: ModuleLike, value: Value, format: str = "", *args, src_loc_at: int = 0, **kwargs): - """Define an assertion. - - This function might help find some hardware bugs which might otherwise be - hard to detect. If `value` is false, it will terminate the simulation or - it can also be used to turn on a warning LED on a board. - - Internally, this is a convenience wrapper over log.error. - - See `HardwareLogger.log` function for more details. 
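A short sketch of typical logger usage (illustrative; the module, signals and logger name are made up for the example):

.. code-block:: python

    from amaranth import Elaboratable, Signal
    from transactron import TModule
    from transactron.lib import logging

    log = logging.HardwareLogger("frontend.icache")

    class CacheProbe(Elaboratable):
        def elaborate(self, platform):
            m = TModule()
            hit = Signal()
            addr = Signal(32)
            # Reported in every cycle in which `hit` is high.
            log.debug(m, hit, "cache hit at {:x}", addr)
            # Terminates the simulation if the invariant is ever violated.
            log.assertion(m, ~(hit & (addr == 0)), "hit at a null address")
            return m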
- """ - self.error(m, ~value, format, *args, **kwargs, src_loc_at=src_loc_at + 1) - - -def get_log_records(level: LogLevel, namespace_regexp: str = ".*") -> list[LogRecord]: - """Get log records in for the given severity level and in the - specified namespace. - - This function returns all log records with the severity bigger or equal - to the specified level and belonging to the specified namespace. - - Parameters - ---------- - level: LogLevel - The minimum severity level. - namespace: str, optional - The regexp of the namespace. If not specified, logs from all namespaces - will be processed. - """ - - dependencies = DependencyContext.get() - all_logs = dependencies.get_dependency(LogKey()) - return [rec for rec in all_logs if rec.level >= level and re.search(namespace_regexp, rec.logger_name)] - - -def get_trigger_bit(level: LogLevel, namespace_regexp: str = ".*") -> Value: - """Get a trigger bit for logs of the given severity level and - in the specified namespace. - - The signal returned by this function is high whenever the trigger signal - of any of the records with the severity bigger or equal to the specified - level is high. - - Parameters - ---------- - level: LogLevel - The minimum severity level. - namespace: str, optional - The regexp of the namespace. If not specified, logs from all namespaces - will be processed. - """ - - return reduce(operator.or_, [rec.trigger for rec in get_log_records(level, namespace_regexp)], C(0)) diff --git a/transactron/lib/metrics.py b/transactron/lib/metrics.py deleted file mode 100644 index 78f5c5e53..000000000 --- a/transactron/lib/metrics.py +++ /dev/null @@ -1,822 +0,0 @@ -from dataclasses import dataclass, field -from dataclasses_json import dataclass_json -from typing import Optional, Type -from abc import ABC -from enum import Enum - -from amaranth import * -from amaranth.utils import bits_for, ceil_log2, exact_log2 - -from transactron.utils import ValueLike, OneHotSwitchDynamic, SignalBundle -from transactron import Method, def_method, TModule -from transactron.lib import FIFO, AsyncMemoryBank, logging -from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey - -__all__ = [ - "MetricRegisterModel", - "MetricModel", - "HwMetric", - "HwCounter", - "TaggedCounter", - "HwExpHistogram", - "FIFOLatencyMeasurer", - "TaggedLatencyMeasurer", - "HardwareMetricsManager", - "HwMetricsEnabledKey", -] - - -@dataclass_json -@dataclass(frozen=True) -class MetricRegisterModel: - """ - Represents a single register of a metric, serving as a fundamental - building block that holds a singular value. - - Attributes - ---------- - name: str - The unique identifier for the register (among remaning - registers of a specific metric). - description: str - A brief description of the metric's purpose. - width: int - The bit-width of the register. - """ - - name: str - description: str - width: int - - -@dataclass_json -@dataclass -class MetricModel: - """ - Provides information about a metric exposed by the circuit. Each metric - comprises multiple registers, each dedicated to storing specific values. - - The configuration of registers is internally determined by a - specific metric type and is not user-configurable. - - Attributes - ---------- - fully_qualified_name: str - The fully qualified name of the metric, with name components joined by dots ('.'), - e.g., 'foo.bar.requests'. - description: str - A human-readable description of the metric's functionality. 
- regs: list[MetricRegisterModel] - A list of registers associated with the metric. - """ - - fully_qualified_name: str - description: str - regs: dict[str, MetricRegisterModel] = field(default_factory=dict) - - -class HwMetricRegister(MetricRegisterModel): - """ - A concrete implementation of a metric register that holds its value as Amaranth signal. - - Attributes - ---------- - value: Signal - Amaranth signal representing the value of the register. - """ - - def __init__(self, name: str, width_bits: int, description: str = "", init: int = 0): - """ - Parameters - ---------- - name: str - The unique identifier for the register (among remaning - registers of a specific metric). - width: int - The bit-width of the register. - description: str - A brief description of the metric's purpose. - init: int - The reset value of the register. - """ - super().__init__(name, description, width_bits) - - self.value = Signal(width_bits, init=init, name=name) - - -@dataclass(frozen=True) -class HwMetricsListKey(ListKey["HwMetric"]): - """DependencyManager key collecting hardware metrics globally as a list.""" - - pass - - -@dataclass(frozen=True) -class HwMetricsEnabledKey(SimpleKey[bool]): - """ - DependencyManager key for enabling hardware metrics. If metrics are disabled, - none of theirs signals will be synthesized. - """ - - lock_on_get = False - empty_valid = True - default_value = False - - -class HwMetric(ABC, MetricModel): - """ - A base for all metric implementations. It should be only used for declaring - new types of metrics. - - It takes care of registering the metric in the dependency manager. - - Attributes - ---------- - signals: dict[str, Signal] - A mapping from a register name to a Signal containing the value of that register. - """ - - def __init__(self, fully_qualified_name: str, description: str): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - """ - super().__init__(fully_qualified_name, description) - - self.signals: dict[str, Signal] = {} - - # add the metric to the global list of all metrics - DependencyContext.get().add_dependency(HwMetricsListKey(), self) - - # So Amaranth doesn't report that the module is unused when metrics are disabled - self._MustUse__silence = True # type: ignore - - def add_registers(self, regs: list[HwMetricRegister]): - """ - Adds registers to a metric. Should be only called by inheriting classes - during initialization. - - Parameters - ---------- - regs: list[HwMetricRegister] - A list of registers to be registered. - """ - for reg in regs: - if reg.name in self.regs: - raise RuntimeError(f"Register {reg.name}' is already added to the metric {self.fully_qualified_name}") - - self.regs[reg.name] = reg - self.signals[reg.name] = reg.value - - def metrics_enabled(self) -> bool: - return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) - - # To restore hashability lost by dataclass subclassing - def __hash__(self): - return object.__hash__(self) - - -class HwCounter(Elaboratable, HwMetric): - """Hardware Counter - - The most basic hardware metric that can just increase its value. - """ - - def __init__(self, fully_qualified_name: str, description: str = "", *, width_bits: int = 32): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. 
- width_bits: int - The bit-width of the register. Defaults to 32 bits. - """ - - super().__init__(fully_qualified_name, description) - - self.count = HwMetricRegister("count", width_bits, "the value of the counter") - - self.add_registers([self.count]) - - self._incr = Method() - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - @def_method(m, self._incr) - def _(): - m.d.sync += self.count.value.eq(self.count.value + 1) - - return m - - def incr(self, m: TModule, *, cond: ValueLike = C(1)): - """ - Increases the value of the counter by 1. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - """ - if not self.metrics_enabled(): - return - - with m.If(cond): - self._incr(m) - - -class TaggedCounter(Elaboratable, HwMetric): - """Hardware Tagged Counter - - Like HwCounter, but contains multiple counters, each with its own tag. - At a time a single counter can be increased and the value of the tag - can be provided dynamically. The type of the tag can be either an int - enum, a range or a list of integers (negative numbers are ok). - - Internally, it detects if tag values can be one-hot encoded and if so, - it generates more optimized circuit. - - Attributes - ---------- - tag_width: int - The length of the signal holding a tag value. - one_hot: bool - Whether tag values can be one-hot encoded. - counters: dict[int, HwMetricRegisters] - Mapping from a tag value to a register holding a counter for that tag. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - tags: range | Type[Enum] | list[int], - registers_width: int = 32, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - tags: range | Type[Enum] | list[int] - Tag values. - registers_width: int - Width of the underlying registers. Defaults to 32 bits. 
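A usage sketch combining both counter types (illustrative; the names, layout and an active `DependencyContext` with metrics enabled are assumptions):

.. code-block:: python

    from enum import IntEnum, unique
    from amaranth import Elaboratable
    from transactron import TModule, Method, def_method
    from transactron.lib.metrics import HwCounter, TaggedCounter

    @unique
    class AccessKind(IntEnum):
        READ = 0
        WRITE = 1

    class CountingUnit(Elaboratable):
        def __init__(self):
            self.requests = HwCounter("example.requests", "all accepted requests")
            self.kinds = TaggedCounter("example.kinds", "requests by kind", tags=AccessKind)
            self.issue = Method(i=[("kind", 1)])

        def elaborate(self, platform):
            m = TModule()
            m.submodules.requests = self.requests
            m.submodules.kinds = self.kinds

            @def_method(m, self.issue)
            def _(kind):
                self.requests.incr(m)     # unconditional increment
                self.kinds.incr(m, kind)  # tag selected at runtime
            return m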
- """ - - super().__init__(fully_qualified_name, description) - - if isinstance(tags, range) or isinstance(tags, list): - counters_meta = [(i, f"{i}") for i in tags] - else: - counters_meta = [(i.value, i.name) for i in tags] - - values = [value for value, _ in counters_meta] - self.tag_width = max(bits_for(max(values)), bits_for(min(values))) - - self.one_hot = True - negative_values = False - for value in values: - if value < 0: - self.one_hot = False - negative_values = True - break - - log = ceil_log2(value) - if 2**log != value: - self.one_hot = False - - self._incr = Method(i=[("tag", Shape(self.tag_width, signed=negative_values))]) - - self.counters: dict[int, HwMetricRegister] = {} - for tag_value, name in counters_meta: - value_str = ("1<<" + str(exact_log2(tag_value))) if self.one_hot else str(tag_value) - description = f"the counter for tag {name} (value={value_str})" - - self.counters[tag_value] = HwMetricRegister( - name, - registers_width, - description, - ) - - self.add_registers(list(self.counters.values())) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - @def_method(m, self._incr) - def _(tag): - if self.one_hot: - sorted_tags = sorted(list(self.counters.keys())) - for i in OneHotSwitchDynamic(m, tag): - counter = self.counters[sorted_tags[i]] - m.d.sync += counter.value.eq(counter.value + 1) - else: - for tag_value, counter in self.counters.items(): - with m.If(tag == tag_value): - m.d.sync += counter.value.eq(counter.value + 1) - - return m - - def incr(self, m: TModule, tag: ValueLike, *, cond: ValueLike = C(1)): - """ - Increases the counter of a given tag by 1. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - tag: ValueLike - The tag of the counter. - cond: ValueLike - When set to high, the counter will be increased. By default set to high. - """ - if not self.metrics_enabled(): - return - - with m.If(cond): - self._incr(m, tag) - - -class HwExpHistogram(Elaboratable, HwMetric): - """Hardware Exponential Histogram - - Represents the distribution of sampled data through a histogram. A histogram - samples observations (usually things like request durations or queue sizes) and counts - them in a configurable number of buckets. The buckets are of exponential size. For example, - a histogram with 5 buckets would have the following value ranges: - [0, 1); [1, 2); [2, 4); [4, 8); [8, +inf). - - Additionally, the histogram tracks the number of observations, the sum - of observed values, and the minimum and maximum values. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - bucket_count: int, - sample_width: int = 32, - registers_width: int = 32, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - max_value: int - The maximum value that the histogram would be able to count. This - value is used to calculate the number of buckets. 
- """ - - super().__init__(fully_qualified_name, description) - self.bucket_count = bucket_count - self.sample_width = sample_width - - self._add = Method(i=[("sample", self.sample_width)]) - - self.count = HwMetricRegister("count", registers_width, "the count of events that have been observed") - self.sum = HwMetricRegister("sum", registers_width, "the total sum of all observed values") - self.min = HwMetricRegister( - "min", - self.sample_width, - "the minimum of all observed values", - init=(1 << self.sample_width) - 1, - ) - self.max = HwMetricRegister("max", self.sample_width, "the maximum of all observed values") - - self.buckets = [] - for i in range(self.bucket_count): - bucket_start = 0 if i == 0 else 2 ** (i - 1) - bucket_end = "inf" if i == self.bucket_count - 1 else 2**i - - self.buckets.append( - HwMetricRegister( - f"bucket-{bucket_end}", - registers_width, - f"the cumulative counter for the observation bucket [{bucket_start}, {bucket_end})", - ) - ) - - self.add_registers([self.count, self.sum, self.max, self.min] + self.buckets) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - @def_method(m, self._add) - def _(sample): - m.d.sync += self.count.value.eq(self.count.value + 1) - m.d.sync += self.sum.value.eq(self.sum.value + sample) - - with m.If(sample > self.max.value): - m.d.sync += self.max.value.eq(sample) - - with m.If(sample < self.min.value): - m.d.sync += self.min.value.eq(sample) - - # todo: perhaps replace with a recursive implementation of the priority encoder - bucket_idx = Signal(range(self.sample_width)) - for i in range(self.sample_width): - with m.If(sample[i]): - m.d.av_comb += bucket_idx.eq(i) - - for i, bucket in enumerate(self.buckets): - should_incr = C(0) - if i == 0: - # The first bucket has a range [0, 1). - should_incr = sample == 0 - elif i == self.bucket_count - 1: - # The last bucket should count values bigger or equal to 2**(self.bucket_count-1) - should_incr = (bucket_idx >= i - 1) & (sample != 0) - else: - should_incr = (bucket_idx == i - 1) & (sample != 0) - - with m.If(should_incr): - m.d.sync += bucket.value.eq(bucket.value + 1) - - return m - - def add(self, m: TModule, sample: Value): - """ - Adds a new sample to the histogram. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - sample: ValueLike - The value that will be added to the histogram - """ - - if not self.metrics_enabled(): - return - - self._add(m, sample) - - -class FIFOLatencyMeasurer(Elaboratable): - """ - Measures duration between two events, e.g. request processing latency. - It can track multiple events at the same time, i.e. the second event can - be registered as started, before the first finishes. However, they must be - processed in the FIFO order. - - The module exposes an exponential histogram of the measured latencies. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - slots_number: int, - max_latency: int, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - slots_number: int - A number of events that the module can track simultaneously. - max_latency: int - The maximum latency of an event. Used to set signal widths and - number of buckets in the histogram. 
If a latency turns to be - bigger than the maximum, it will overflow and result in a false - measurement. - """ - self.fully_qualified_name = fully_qualified_name - self.description = description - self.slots_number = slots_number - self.max_latency = max_latency - - self._start = Method() - self._stop = Method() - - # This bucket count gives us the best possible granularity. - bucket_count = bits_for(self.max_latency) + 1 - self.histogram = HwExpHistogram( - self.fully_qualified_name, - self.description, - bucket_count=bucket_count, - sample_width=bits_for(self.max_latency), - ) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - epoch_width = bits_for(self.max_latency) - - m.submodules.fifo = self.fifo = FIFO([("epoch", epoch_width)], self.slots_number) - m.submodules.histogram = self.histogram - - epoch = Signal(epoch_width) - - m.d.sync += epoch.eq(epoch + 1) - - @def_method(m, self._start) - def _(): - self.fifo.write(m, epoch) - - @def_method(m, self._stop) - def _(): - ret = self.fifo.read(m) - # The result of substracting two unsigned n-bit is a signed (n+1)-bit value, - # so we need to cast the result and discard the most significant bit. - duration = (epoch - ret.epoch).as_unsigned()[:-1] - self.histogram.add(m, duration) - - return m - - def start(self, m: TModule): - """ - Registers the start of an event. Can be called before the previous events - finish. If there are no slots available, the method will be blocked. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - """ - - if not self.metrics_enabled(): - return - - self._start(m) - - def stop(self, m: TModule): - """ - Registers the end of the oldest event (the FIFO order). If there are no - started events in the queue, the method will block. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - """ - - if not self.metrics_enabled(): - return - - self._stop(m) - - def metrics_enabled(self) -> bool: - return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) - - -class TaggedLatencyMeasurer(Elaboratable): - """ - Measures duration between two events, e.g. request processing latency. - It can track multiple events at the same time, i.e. the second event can - be registered as started, before the first finishes. However, each event - needs to have an unique slot tag. - - The module exposes an exponential histogram of the measured latencies. - """ - - def __init__( - self, - fully_qualified_name: str, - description: str = "", - *, - slots_number: int, - max_latency: int, - ): - """ - Parameters - ---------- - fully_qualified_name: str - The fully qualified name of the metric. - description: str - A human-readable description of the metric's functionality. - slots_number: int - A number of events that the module can track simultaneously. - max_latency: int - The maximum latency of an event. Used to set signal widths and - number of buckets in the histogram. If a latency turns to be - bigger than the maximum, it will overflow and result in a false - measurement. - """ - self.fully_qualified_name = fully_qualified_name - self.description = description - self.slots_number = slots_number - self.max_latency = max_latency - - self._start = Method(i=[("slot", range(0, slots_number))]) - self._stop = Method(i=[("slot", range(0, slots_number))]) - - # This bucket count gives us the best possible granularity. 
- bucket_count = bits_for(self.max_latency) + 1 - self.histogram = HwExpHistogram( - self.fully_qualified_name, - self.description, - bucket_count=bucket_count, - sample_width=bits_for(self.max_latency), - ) - - self.log = logging.HardwareLogger(fully_qualified_name) - - def elaborate(self, platform): - if not self.metrics_enabled(): - return TModule() - - m = TModule() - - epoch_width = bits_for(self.max_latency) - - m.submodules.slots = self.slots = AsyncMemoryBank( - data_layout=[("epoch", epoch_width)], elem_count=self.slots_number - ) - m.submodules.histogram = self.histogram - - slots_taken = Signal(self.slots_number) - slots_taken_start = Signal.like(slots_taken) - slots_taken_stop = Signal.like(slots_taken) - - m.d.comb += slots_taken_start.eq(slots_taken) - m.d.comb += slots_taken_stop.eq(slots_taken_start) - m.d.sync += slots_taken.eq(slots_taken_stop) - - epoch = Signal(epoch_width) - - m.d.sync += epoch.eq(epoch + 1) - - @def_method(m, self._start) - def _(slot: Value): - m.d.comb += slots_taken_start.eq(slots_taken | (1 << slot)) - self.log.error(m, (slots_taken & (1 << slot)).any(), "taken slot {} taken again", slot) - self.slots.write(m, addr=slot, data=epoch) - - @def_method(m, self._stop) - def _(slot: Value): - m.d.comb += slots_taken_stop.eq(slots_taken_start & ~(C(1, self.slots_number) << slot)) - self.log.error(m, ~(slots_taken & (1 << slot)).any(), "free slot {} freed again", slot) - ret = self.slots.read(m, addr=slot) - # The result of substracting two unsigned n-bit is a signed (n+1)-bit value, - # so we need to cast the result and discard the most significant bit. - duration = (epoch - ret.epoch).as_unsigned()[:-1] - self.histogram.add(m, duration) - - return m - - def start(self, m: TModule, *, slot: ValueLike): - """ - Registers the start of an event for a given slot tag. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - slot: ValueLike - The slot tag of the event. - """ - - if not self.metrics_enabled(): - return - - self._start(m, slot) - - def stop(self, m: TModule, *, slot: ValueLike): - """ - Registers the end of the event for a given slot tag. - - Should be called in the body of either a transaction or a method. - - Parameters - ---------- - m: TModule - Transactron module - slot: ValueLike - The slot tag of the event. - """ - - if not self.metrics_enabled(): - return - - self._stop(m, slot) - - def metrics_enabled(self) -> bool: - return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) - - -class HardwareMetricsManager: - """ - Collects all metrics registered in the circuit and provides an easy - access to them. - """ - - def __init__(self): - self._metrics: Optional[dict[str, HwMetric]] = None - - def _collect_metrics(self) -> dict[str, HwMetric]: - # We lazily collect all metrics so that the metrics manager can be - # constructed at any time. Otherwise, if a metric object was created - # after the manager object had been created, that metric wouldn't end up - # being registered. - metrics: dict[str, HwMetric] = {} - for metric in DependencyContext.get().get_dependency(HwMetricsListKey()): - if metric.fully_qualified_name in metrics: - raise RuntimeError(f"Metric '{metric.fully_qualified_name}' is already registered") - - metrics[metric.fully_qualified_name] = metric - - return metrics - - def get_metrics(self) -> dict[str, HwMetric]: - """ - Returns all metrics registered in the circuit. 
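A sketch of reading metrics back, e.g. from a testbench (illustrative; the metric and register names are assumptions):

.. code-block:: python

    from transactron.lib.metrics import HardwareMetricsManager

    # Requires an active DependencyContext in which the metrics were created.
    manager = HardwareMetricsManager()

    # The signal backing a single metric register...
    count = manager.get_register_value("example.requests", "count")

    # ...or all metric registers as a tree, e.g. for waveform dumps.
    bundle = manager.debug_signals()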
- """ - if self._metrics is None: - self._metrics = self._collect_metrics() - return self._metrics - - def get_register_value(self, metric_name: str, reg_name: str) -> Signal: - """ - Returns the signal holding the register value of the given metric. - - Parameters - ---------- - metric_name: str - The fully qualified name of the metric, for example 'frontend.icache.loads'. - reg_name: str - The name of the register from that metric, for example if - the metric is a histogram, the 'reg_name' could be 'min' - or 'bucket-32'. - """ - - metrics = self.get_metrics() - if metric_name not in metrics: - raise RuntimeError(f"Couldn't find metric '{metric_name}'") - return metrics[metric_name].signals[reg_name] - - def debug_signals(self) -> SignalBundle: - """ - Returns tree-like SignalBundle composed of all metric registers. - """ - metrics = self.get_metrics() - - def rec(metric_names: list[str], depth: int = 1): - bundle: list[SignalBundle] = [] - components: dict[str, list[str]] = {} - - for metric in metric_names: - parts = metric.split(".") - - if len(parts) == depth: - signals = metrics[metric].signals - reg_values = [signals[reg_name] for reg_name in signals] - - bundle.append({metric: reg_values}) - - continue - - component_prefix = ".".join(parts[:depth]) - - if component_prefix not in components: - components[component_prefix] = [] - components[component_prefix].append(metric) - - for component_name, elements in components.items(): - bundle.append({component_name: rec(elements, depth + 1)}) - - return bundle - - return {"metrics": rec(list(self.get_metrics().keys()))} diff --git a/transactron/lib/reqres.py b/transactron/lib/reqres.py deleted file mode 100644 index a3f6e2908..000000000 --- a/transactron/lib/reqres.py +++ /dev/null @@ -1,185 +0,0 @@ -from amaranth import * -from ..core import * -from ..utils import SrcLoc, get_src_loc, MethodLayout -from .connectors import Forwarder -from transactron.lib import BasicFifo -from amaranth.utils import * - -__all__ = [ - "ArgumentsToResultsZipper", - "Serializer", -] - - -class ArgumentsToResultsZipper(Elaboratable): - """Zips arguments used to call method with results, cutting critical path. - - This module provides possibility to pass arguments from caller and connect it with results - from callee. Arguments are stored in 2-FIFO and results in Forwarder. Because of this asymmetry, - the callee should provide results as long as they aren't correctly received. - - FIFO is used as rate-limiter, so when FIFO reaches full capacity there should be no new requests issued. - - Example topology: - - .. mermaid:: - - graph LR - Caller -- write_arguments --> 2-FIFO; - Caller -- invoke --> Callee["Callee \\n (1+ cycle delay)"]; - Callee -- write_results --> Forwarder; - Forwarder -- read --> Zip; - 2-FIFO -- read --> Zip; - Zip -- read --> User; - subgraph ArgumentsToResultsZipper - Forwarder; - 2-FIFO; - Zip; - end - - Attributes - ---------- - peek_arg: Method - A nonexclusive method to read (but not delete) the head of the arg queue. - write_args: Method - Method to write arguments with `args_layout` format to 2-FIFO. - write_results: Method - Method to save results with `results_layout` in the Forwarder. - read: Method - Reads latest entries from the fifo and the forwarder and return them as - a structure with two fields: 'args' and 'results'. - """ - - def __init__(self, args_layout: MethodLayout, results_layout: MethodLayout, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - args_layout: method layout - The format of arguments. 
- results_layout: method layout - The format of results. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.src_loc = get_src_loc(src_loc) - self.results_layout = results_layout - self.args_layout = args_layout - self.output_layout = [("args", self.args_layout), ("results", results_layout)] - - self.peek_arg = Method(o=self.args_layout, nonexclusive=True, src_loc=self.src_loc) - self.write_args = Method(i=self.args_layout, src_loc=self.src_loc) - self.write_results = Method(i=self.results_layout, src_loc=self.src_loc) - self.read = Method(o=self.output_layout, src_loc=self.src_loc) - - def elaborate(self, platform): - m = TModule() - - fifo = BasicFifo(self.args_layout, depth=2, src_loc=self.src_loc) - forwarder = Forwarder(self.results_layout, src_loc=self.src_loc) - - m.submodules.fifo = fifo - m.submodules.forwarder = forwarder - - @def_method(m, self.write_args) - def _(arg): - fifo.write(m, arg) - - @def_method(m, self.write_results) - def _(arg): - forwarder.write(m, arg) - - @def_method(m, self.read) - def _(): - args = fifo.read(m) - results = forwarder.read(m) - return {"args": args, "results": results} - - self.peek_arg.proxy(m, fifo.peek) - - return m - - -class Serializer(Elaboratable): - """Module to serialize request-response methods. - - Provides a transactional interface to connect many client `Module`\\s (which request something using a method call) - with a server `Module` which provides a method to request an operation and a method to get the response. - - Requests are serialized from many clients and forwarded to a server which can process only one request - at a time. Responses from the server are deserialized and passed to the proper clients. `Serializer` assumes that - responses from the server are in-order, so the order of responses is the same as the order of requests. - - - Attributes - ---------- - serialize_in: list[Method] - List of request methods. Data layouts are the same as for `serialized_req_method`. - serialize_out: list[Method] - List of response methods. Data layouts are the same as for `serialized_resp_method`. - `i`-th response method provides responses for requests from `i`-th `serialize_in` method. - """ - - def __init__( - self, - *, - port_count: int, - serialized_req_method: Method, - serialized_resp_method: Method, - depth: int = 4, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - port_count: int - Number of ports which should be generated. `len(serialize_in)=len(serialize_out)=port_count` - serialized_req_method: Method - Request method provided by the server's `Module`. - serialized_resp_method: Method - Response method provided by the server's `Module`. - depth: int - Number of requests which can be forwarded to the server before the server provides the first response. - Describes the resilience of `Serializer` to the latency of the server when the server is fully pipelined. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default.
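A sketch of serializing two clients onto a single-ported memory (illustrative; `MemoryBank` is used here only as an example of a request/response server):

.. code-block:: python

    from amaranth import Elaboratable
    from transactron import TModule
    from transactron.lib.reqres import Serializer
    from transactron.lib.storage import MemoryBank

    class SharedMemory(Elaboratable):
        def elaborate(self, platform):
            m = TModule()
            m.submodules.mem = mem = MemoryBank(data_layout=[("data", 32)], elem_count=16)
            # Responses come back in request order, so each client receives
            # the answers to its own requests via serialize_out[i].
            m.submodules.ser = ser = Serializer(
                port_count=2,
                serialized_req_method=mem.read_req,
                serialized_resp_method=mem.read_resp,
            )
            # ser.serialize_in[i] / ser.serialize_out[i] are the per-client methods.
            return m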
- """ - self.src_loc = get_src_loc(src_loc) - self.port_count = port_count - self.serialized_req_method = serialized_req_method - self.serialized_resp_method = serialized_resp_method - - self.depth = depth - - self.id_layout = [("id", exact_log2(self.port_count))] - - self.clear = Method() - self.serialize_in = [ - Method.like(self.serialized_req_method, src_loc=self.src_loc) for _ in range(self.port_count) - ] - self.serialize_out = [ - Method.like(self.serialized_resp_method, src_loc=self.src_loc) for _ in range(self.port_count) - ] - - def elaborate(self, platform) -> TModule: - m = TModule() - - pending_requests = BasicFifo(self.id_layout, self.depth, src_loc=self.src_loc) - m.submodules.pending_requests = pending_requests - - for i in range(self.port_count): - - @def_method(m, self.serialize_in[i]) - def _(arg): - pending_requests.write(m, {"id": i}) - self.serialized_req_method(m, arg) - - @def_method(m, self.serialize_out[i], ready=(pending_requests.head.id == i)) - def _(): - pending_requests.read(m) - return self.serialized_resp_method(m) - - self.clear.proxy(m, pending_requests.clear) - - return m diff --git a/transactron/lib/simultaneous.py b/transactron/lib/simultaneous.py deleted file mode 100644 index 7b00f93ff..000000000 --- a/transactron/lib/simultaneous.py +++ /dev/null @@ -1,87 +0,0 @@ -from amaranth import * - -from ..utils import SrcLoc -from ..core import * -from ..core import TransactionBase -from contextlib import contextmanager -from typing import Optional -from transactron.utils import ValueLike - -__all__ = [ - "condition", -] - - -@contextmanager -def condition(m: TModule, *, nonblocking: bool = False, priority: bool = False): - """Conditions using simultaneous transactions. - - This context manager allows to easily define conditions utilizing - nested transactions and the simultaneous transactions mechanism. - It is similar to Amaranth's `If`, but allows to call different and - possibly overlapping method sets in each branch. Each of the branches is - defined using a separate nested transaction. - - Inside the condition body, branches can be added, which are guarded - by Boolean conditions. A branch is considered for execution if its - condition is true and the called methods can be run. A catch-all, - default branch can be added, which can be executed only if none of - the other branches execute. The condition of the default branch is - the negated alternative of all the other conditions. - - Parameters - ---------- - m : TModule - A module where the condition is defined. - nonblocking : bool - States that the condition should not block the containing method - or transaction from running, even when none of the branch - conditions is true. In case of a blocking method call, the - containing method or transaction is still blocked. - priority : bool - States that when conditions are not mutually exclusive and multiple - branches could be executed, the first one will be selected. This - influences the scheduling order of generated transactions. - - Examples - -------- - .. highlight:: python - .. code-block:: python - - with condition(m) as branch: - with branch(cond1): - ... - with branch(cond2): - ... - with branch(): # default, optional - ... 
- """ - this = TransactionBase.get() - transactions = list[Transaction]() - last = False - conds = list[Signal]() - - @contextmanager - def branch(cond: Optional[ValueLike] = None, *, src_loc: int | SrcLoc = 2): - nonlocal last - if last: - raise RuntimeError("Condition clause added after catch-all") - req = Signal() - m.d.top_comb += req.eq(cond if cond is not None else ~Cat(*conds).any()) - conds.append(req) - name = f"{this.name}_cond{len(transactions)}" - with (transaction := Transaction(name=name, src_loc=src_loc)).body(m, request=req): - yield - if transactions and priority: - transactions[-1].schedule_before(transaction) - if cond is None: - last = True - transactions.append(transaction) - - yield branch - - if nonblocking and not last: - with branch(): - pass - - this.simultaneous_alternatives(*transactions) diff --git a/transactron/lib/storage.py b/transactron/lib/storage.py deleted file mode 100644 index e402df7ec..000000000 --- a/transactron/lib/storage.py +++ /dev/null @@ -1,345 +0,0 @@ -from amaranth import * -from amaranth.utils import * -import amaranth.lib.memory as memory - -from transactron.utils.transactron_helpers import from_method_layout, make_layout -from ..core import * -from ..utils import SrcLoc, get_src_loc, MultiPriorityEncoder -from typing import Optional -from transactron.utils import LayoutList, MethodLayout - -__all__ = ["MemoryBank", "ContentAddressableMemory", "AsyncMemoryBank"] - - -class MemoryBank(Elaboratable): - """MemoryBank module. - - Provides a transactional interface to synchronous Amaranth Memory with arbitrary - number of read and write ports. It supports optionally writing with given granularity. - - Attributes - ---------- - read_reqs: list[Method] - The read request methods, one for each read port. Accepts an `addr` from which data should be read. - Only ready if there is there is a place to buffer response. After calling `read_reqs[i]`, the result - will be available via the method `read_resps[i]`. - read_resps: list[Method] - The read response methods, one for each read port. Return `data_layout` View which was saved on `addr` given - by last corresponding `read_reqs` method call. Only ready after corresponding `read_reqs` call. - writes: list[Method] - The write methods, one for each write port. Accepts write address `addr`, `data` in form of `data_layout` - and optionally `mask` if `granularity` is not None. `1` in mask means that appropriate part should be written. - read_req: Method - The only method from `read_reqs`, if the memory has a single read port. If it has more ports, this method - is unavailable and `read_reqs` should be used instead. - read_resp: Method - The only method from `read_resps`, if the memory has a single read port. If it has more ports, this method - is unavailable and `read_resps` should be used instead. - write: Method - The only method from `writes`, if the memory has a single write port. If it has more ports, this method - is unavailable and `writes` should be used instead. - """ - - def __init__( - self, - *, - data_layout: LayoutList, - elem_count: int, - granularity: Optional[int] = None, - transparent: bool = False, - read_ports: int = 1, - write_ports: int = 1, - src_loc: int | SrcLoc = 0, - ): - """ - Parameters - ---------- - data_layout: method layout - The format of structures stored in the Memory. - elem_count: int - Number of elements stored in Memory. - granularity: Optional[int] - Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once. 
- If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently. - transparent: bool - Read port transparency, false by default. When a read port is transparent, if a given memory address - is read and written in the same clock cycle, the read returns the written value instead of the value - which was in the memory in that cycle. - read_ports: int - Number of read ports. - write_ports: int - Number of write ports. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.src_loc = get_src_loc(src_loc) - self.data_layout = make_layout(*data_layout) - self.elem_count = elem_count - self.granularity = granularity - self.width = from_method_layout(self.data_layout).size - self.addr_width = bits_for(self.elem_count - 1) - self.transparent = transparent - self.reads_ports = read_ports - self.writes_ports = write_ports - - self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] - write_layout = [("addr", self.addr_width), ("data", self.data_layout)] - if self.granularity is not None: - write_layout.append(("mask", self.width // self.granularity)) - self.writes_layout = make_layout(*write_layout) - - self.read_reqs = [Method(i=self.read_reqs_layout, src_loc=self.src_loc) for _ in range(read_ports)] - self.read_resps = [Method(o=self.data_layout, src_loc=self.src_loc) for _ in range(read_ports)] - self.writes = [Method(i=self.writes_layout, src_loc=self.src_loc) for _ in range(write_ports)] - - if read_ports == 1: - self.read_req = self.read_reqs[0] - self.read_resp = self.read_resps[0] - if write_ports == 1: - self.write = self.writes[0] - - def elaborate(self, platform) -> TModule: - m = TModule() - - m.submodules.mem = mem = memory.Memory(shape=self.width, depth=self.elem_count, init=[]) - write_port = [mem.write_port() for _ in range(self.writes_ports)] - read_port = [ - mem.read_port(transparent_for=write_port if self.transparent else []) for _ in range(self.reads_ports) - ] - read_output_valid = [Signal() for _ in range(self.reads_ports)] - overflow_valid = [Signal() for _ in range(self.reads_ports)] - overflow_data = [Signal(self.width) for _ in range(self.reads_ports)] - - # The read request method can be called at most twice when not reading the response. - # The first result is stored in the overflow buffer, the second - in the read value buffer of the memory. - # If the responses are always read as they arrive, overflow is never written and no stalls occur. 
- - for i in range(self.reads_ports): - with m.If(read_output_valid[i] & ~overflow_valid[i] & self.read_reqs[i].run & ~self.read_resps[i].run): - m.d.sync += overflow_valid[i].eq(1) - m.d.sync += overflow_data[i].eq(read_port[i].data) - - @def_methods(m, self.read_resps, lambda i: read_output_valid[i] | overflow_valid[i]) - def _(i: int): - with m.If(overflow_valid[i]): - m.d.sync += overflow_valid[i].eq(0) - with m.Else(): - m.d.sync += read_output_valid[i].eq(0) - return Mux(overflow_valid[i], overflow_data[i], read_port[i].data) - - for i in range(self.reads_ports): - m.d.comb += read_port[i].en.eq(0) # because the init value is 1 - - @def_methods(m, self.read_reqs, lambda i: ~overflow_valid[i]) - def _(i: int, addr): - m.d.sync += read_output_valid[i].eq(1) - m.d.comb += read_port[i].en.eq(1) - m.d.comb += read_port[i].addr.eq(addr) - - @def_methods(m, self.writes) - def _(i: int, arg): - m.d.comb += write_port[i].addr.eq(arg.addr) - m.d.comb += write_port[i].data.eq(arg.data) - if self.granularity is None: - m.d.comb += write_port[i].en.eq(1) - else: - m.d.comb += write_port[i].en.eq(arg.mask) - - return m - - -class ContentAddressableMemory(Elaboratable): - """Content addressable memory - - This module implements a content-addressable memory (CAM for short) with a Transactron interface. - A CAM is a type of memory where, instead of predefined indices, values provided at runtime are used - as keys (similarly to a Python dictionary). To insert a new entry, a pair `(key, value)` has to be - provided. Such a pair takes a free slot, chosen by the internal implementation. To read a value, - a `key` has to be provided. It is compared with every valid key stored in the CAM. If there is a hit, - the value is read. There can be many instances of the same key in the CAM. In such a case it is - undefined which value will be read. - - - .. warning:: - Pushing a value with a key already present in the CAM is undefined behaviour. - - Attributes - ---------- - read : Method - Nondestructive read. - write : Method - Updates the value if the key is present. - remove : Method - Removes the entry with the given key. - push : Method - Inserts new data. - """ - - def __init__(self, address_layout: MethodLayout, data_layout: MethodLayout, entries_number: int): - """ - Parameters - ---------- - address_layout : MethodLayout - The layout of the address records. - data_layout : MethodLayout - The layout of the data. - entries_number : int - The number of slots to create in memory.
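A lookup sketch (illustrative; the layouts and field names are assumptions):

.. code-block:: python

    from amaranth import Elaboratable
    from transactron import TModule, Transaction
    from transactron.lib.storage import ContentAddressableMemory

    class TranslationLookup(Elaboratable):
        def elaborate(self, platform):
            m = TModule()
            m.submodules.cam = cam = ContentAddressableMemory(
                address_layout=[("vpn", 20)],
                data_layout=[("ppn", 20)],
                entries_number=8,
            )
            with Transaction().body(m):
                res = cam.read(m, addr={"vpn": 0x12345})
                # res.data is the stored value; res.not_found signals a miss.
            return m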
- """ - self.address_layout = from_method_layout(address_layout) - self.data_layout = from_method_layout(data_layout) - self.entries_number = entries_number - - self.read = Method(i=[("addr", self.address_layout)], o=[("data", self.data_layout), ("not_found", 1)]) - self.remove = Method(i=[("addr", self.address_layout)]) - self.push = Method(i=[("addr", self.address_layout), ("data", self.data_layout)]) - self.write = Method(i=[("addr", self.address_layout), ("data", self.data_layout)], o=[("not_found", 1)]) - - def elaborate(self, platform) -> TModule: - m = TModule() - - address_array = Array( - [Signal(self.address_layout, name=f"address_array_{i}") for i in range(self.entries_number)] - ) - data_array = Array([Signal(self.data_layout, name=f"data_array_{i}") for i in range(self.entries_number)]) - valids = Signal(self.entries_number, name="valids") - - m.submodules.encoder_read = encoder_read = MultiPriorityEncoder(self.entries_number, 1) - m.submodules.encoder_write = encoder_write = MultiPriorityEncoder(self.entries_number, 1) - m.submodules.encoder_push = encoder_push = MultiPriorityEncoder(self.entries_number, 1) - m.submodules.encoder_remove = encoder_remove = MultiPriorityEncoder(self.entries_number, 1) - m.d.top_comb += encoder_push.input.eq(~valids) - - @def_method(m, self.push, ready=~valids.all()) - def _(addr, data): - id = Signal(range(self.entries_number), name="id_push") - m.d.top_comb += id.eq(encoder_push.outputs[0]) - m.d.sync += address_array[id].eq(addr) - m.d.sync += data_array[id].eq(data) - m.d.sync += valids.bit_select(id, 1).eq(1) - - @def_method(m, self.write) - def _(addr, data): - write_mask = Signal(self.entries_number, name="write_mask") - m.d.top_comb += write_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) - m.d.top_comb += encoder_write.input.eq(write_mask) - with m.If(write_mask.any()): - m.d.sync += data_array[encoder_write.outputs[0]].eq(data) - return {"not_found": ~write_mask.any()} - - @def_method(m, self.read) - def _(addr): - read_mask = Signal(self.entries_number, name="read_mask") - m.d.top_comb += read_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) - m.d.top_comb += encoder_read.input.eq(read_mask) - return {"data": data_array[encoder_read.outputs[0]], "not_found": ~read_mask.any()} - - @def_method(m, self.remove) - def _(addr): - rm_mask = Signal(self.entries_number, name="rm_mask") - m.d.top_comb += rm_mask.eq(Cat([addr == stored_addr for stored_addr in address_array]) & valids) - m.d.top_comb += encoder_remove.input.eq(rm_mask) - with m.If(rm_mask.any()): - m.d.sync += valids.bit_select(encoder_remove.outputs[0], 1).eq(0) - - return m - - -class AsyncMemoryBank(Elaboratable): - """AsyncMemoryBank module. - - Provides a transactional interface to asynchronous Amaranth Memory with arbitrary number of - read and write ports. It supports optionally writing with given granularity. - - Attributes - ---------- - reads: list[Method] - The read methods, one for each read port. Accepts an `addr` from which data should be read. - The read response method. Return `data_layout` View which was saved on `addr` given by last - `write` method call. - writes: list[Method] - The write methods, one for each write port. Accepts write address `addr`, `data` in form of `data_layout` - and optionally `mask` if `granularity` is not None. `1` in mask means that appropriate part should be written. - read: Method - The only method from `reads`, if the memory has a single read port. 
- write: Method - The only method from `writes`, if the memory has a single write port. - """ - - def __init__( - self, - *, - data_layout: LayoutList, - elem_count: int, - granularity: Optional[int] = None, - read_ports: int = 1, - write_ports: int = 1, - src_loc: int | SrcLoc = 0, - ): - """ - Parameters - ---------- - data_layout: method layout - The format of structures stored in the Memory. - elem_count: int - Number of elements stored in Memory. - granularity: Optional[int] - Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once. - If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently. - read_ports: int - Number of read ports. - write_ports: int - Number of write ports. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - self.src_loc = get_src_loc(src_loc) - self.data_layout = make_layout(*data_layout) - self.elem_count = elem_count - self.granularity = granularity - self.width = from_method_layout(self.data_layout).size - self.addr_width = bits_for(self.elem_count - 1) - self.reads_ports = read_ports - self.writes_ports = write_ports - - self.read_reqs_layout: LayoutList = [("addr", self.addr_width)] - write_layout = [("addr", self.addr_width), ("data", self.data_layout)] - if self.granularity is not None: - write_layout.append(("mask", self.width // self.granularity)) - self.writes_layout = make_layout(*write_layout) - - self.reads = [ - Method(i=self.read_reqs_layout, o=self.data_layout, src_loc=self.src_loc) for _ in range(read_ports) - ] - self.writes = [Method(i=self.writes_layout, src_loc=self.src_loc) for _ in range(write_ports)] - - if read_ports == 1: - self.read = self.reads[0] - if write_ports == 1: - self.write = self.writes[0] - - def elaborate(self, platform) -> TModule: - m = TModule() - - mem = memory.Memory(shape=self.width, depth=self.elem_count, init=[]) - m.submodules.mem = mem - write_port = [mem.write_port() for _ in range(self.writes_ports)] - read_port = [mem.read_port(domain="comb") for _ in range(self.reads_ports)] - - @def_methods(m, self.reads) - def _(i: int, addr): - m.d.comb += read_port[i].addr.eq(addr) - return read_port[i].data - - @def_methods(m, self.writes) - def _(i: int, arg): - m.d.comb += write_port[i].addr.eq(arg.addr) - m.d.comb += write_port[i].data.eq(arg.data) - if self.granularity is None: - m.d.comb += write_port[i].en.eq(1) - else: - m.d.comb += write_port[i].en.eq(arg.mask) - - return m diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py deleted file mode 100644 index cd09816a2..000000000 --- a/transactron/lib/transformers.py +++ /dev/null @@ -1,452 +0,0 @@ -from amaranth import * - -from transactron.utils.transactron_helpers import get_src_loc -from ..core import * -from ..utils import SrcLoc -from typing import Optional, Protocol -from collections.abc import Callable -from transactron.utils import ( - ValueLike, - assign, - AssignType, - ModuleLike, - MethodStruct, - HasElaborate, - MethodLayout, - RecordDict, -) -from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans -from .simultaneous import condition - -__all__ = [ - "Transformer", - "Unifier", - "MethodMap", - "MethodFilter", - "MethodProduct", - "MethodTryProduct", - "Collector", - "CatTrans", - "ConnectAndMapTrans", -] - - -class Transformer(HasElaborate, Protocol): - """Method transformer abstract class. 
- - Method transformers construct a new method which utilizes other methods. - - Attributes - ---------- - method: Method - The method. - """ - - method: Method - - def use(self, m: ModuleLike): - """ - Returns the method and adds the transformer to a module. - - Parameters - ---------- - m: Module or TModule - The module to which this transformer is added as a submodule. - """ - m.submodules += self - return self.method - - -class Unifier(Transformer, Protocol): - method: Method - - def __init__(self, targets: list[Method]): ... - - -class MethodMap(Elaboratable, Transformer): - """Bidirectional map for methods. - - Takes a target method and creates a transformed method which calls the - original target method, mapping the input and output values with - functions. The mapping functions take two parameters, a `Module` and the - structure being transformed. Alternatively, a `Method` can be - passed. - - Attributes - ---------- - method: Method - The transformed method. - """ - - def __init__( - self, - target: Method, - *, - i_transform: Optional[tuple[MethodLayout, Callable[[TModule, MethodStruct], RecordDict]]] = None, - o_transform: Optional[tuple[MethodLayout, Callable[[TModule, MethodStruct], RecordDict]]] = None, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - target: Method - The target method. - i_transform: (method layout, function or Method), optional - Input mapping function. If specified, it should be a pair of a - function and a input layout for the transformed method. - If not present, input is passed unmodified. - o_transform: (method layout, function or Method), optional - Output mapping function. If specified, it should be a pair of a - function and a output layout for the transformed method. - If not present, output is passed unmodified. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - if i_transform is None: - i_transform = (target.layout_in, lambda _, x: x) - if o_transform is None: - o_transform = (target.layout_out, lambda _, x: x) - - self.target = target - src_loc = get_src_loc(src_loc) - self.method = Method(i=i_transform[0], o=o_transform[0], src_loc=src_loc) - self.i_fun = i_transform[1] - self.o_fun = o_transform[1] - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method) - def _(arg): - return self.o_fun(m, self.target(m, self.i_fun(m, arg))) - - return m - - -class MethodFilter(Elaboratable, Transformer): - """Method filter. - - Takes a target method and creates a method which calls the target method - only when some condition is true. The condition function takes two - parameters, a module and the input structure of the method. Non-zero - return value is interpreted as true. Alternatively to using a function, - a `Method` can be passed as a condition. - By default, the target method is locked for use even if it is not called. - If this is not the desired effect, set `use_condition` to True, but this will - cause that the provided method will be `single_caller` and all other `condition` - drawbacks will be in place (e.g. risk of exponential complexity). - - Attributes - ---------- - method: Method - The transformed method. 
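A filtering sketch (illustrative; the target method and its `addr` field are assumptions):

.. code-block:: python

    from amaranth import Elaboratable
    from transactron import TModule, Method
    from transactron.lib.transformers import MethodFilter

    class EvenAddressesOnly(Elaboratable):
        def __init__(self, target: Method):
            # Calls `target` only for even addresses; for odd ones the
            # default (zero) result is returned instead.
            self.filter = MethodFilter(target, lambda m, arg: arg.addr[0] == 0)
            self.method = self.filter.method

        def elaborate(self, platform):
            m = TModule()
            m.submodules.filter = self.filter
            return m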
- """ - - def __init__( - self, - target: Method, - condition: Callable[[TModule, MethodStruct], ValueLike], - default: Optional[RecordDict] = None, - *, - use_condition: bool = False, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - target: Method - The target method. - condition: function or Method - The condition which, when true, allows the call to `target`. When - false, `default` is returned. - default: Value or dict, optional - The default value returned from the filtered method when the condition - is false. If omitted, zero is returned. - use_condition : bool - Instead of `m.If` use simultaneus `condition` which allow to execute - this filter if the condition is False and target is not ready. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - """ - if default is None: - default = Signal.like(target.data_out) - - self.target = target - self.use_condition = use_condition - src_loc = get_src_loc(src_loc) - self.method = Method(i=target.layout_in, o=target.layout_out, single_caller=self.use_condition, src_loc=src_loc) - self.condition = condition - self.default = default - - def elaborate(self, platform): - m = TModule() - - ret = Signal.like(self.target.data_out) - m.d.comb += assign(ret, self.default, fields=AssignType.ALL) - - @def_method(m, self.method) - def _(arg): - if self.use_condition: - cond = Signal() - m.d.top_comb += cond.eq(self.condition(m, arg)) - with condition(m, nonblocking=True) as branch: - with branch(cond): - m.d.comb += ret.eq(self.target(m, arg)) - else: - with m.If(self.condition(m, arg)): - m.d.comb += ret.eq(self.target(m, arg)) - return ret - - return m - - -class MethodProduct(Elaboratable, Unifier): - def __init__( - self, - targets: list[Method], - combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[MethodStruct]], RecordDict]]] = None, - *, - src_loc: int | SrcLoc = 0 - ): - """Method product. - - Takes arbitrary, non-zero number of target methods, and constructs - a method which calls all of the target methods using the same - argument. The return value of the resulting method is, by default, - the return value of the first of the target methods. A combiner - function can be passed, which can compute the return value from - the results of every target method. - - Parameters - ---------- - targets: list[Method] - A list of methods to be called. - combiner: (int or method layout, function), optional - A pair of the output layout and the combiner function. The - combiner function takes two parameters: a `Module` and - a list of outputs of the target methods. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - - Attributes - ---------- - method: Method - The product method. 
- """ - if combiner is None: - combiner = (targets[0].layout_out, lambda _, x: x[0]) - self.targets = targets - self.combiner = combiner - src_loc = get_src_loc(src_loc) - self.method = Method(i=targets[0].layout_in, o=combiner[0], src_loc=src_loc) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method) - def _(arg): - results = [] - for target in self.targets: - results.append(target(m, arg)) - return self.combiner[1](m, results) - - return m - - -class MethodTryProduct(Elaboratable, Unifier): - def __init__( - self, - targets: list[Method], - combiner: Optional[ - tuple[MethodLayout, Callable[[TModule, list[tuple[Value, MethodStruct]]], RecordDict]] - ] = None, - *, - src_loc: int | SrcLoc = 0 - ): - """Method product with optional calling. - - Takes arbitrary, non-zero number of target methods, and constructs - a method which tries to call all of the target methods using the same - argument. The methods which are not ready are not called. The return - value of the resulting method is, by default, empty. A combiner - function can be passed, which can compute the return value from the - results of every target method. - - Parameters - ---------- - targets: list[Method] - A list of methods to be called. - combiner: (int or method layout, function), optional - A pair of the output layout and the combiner function. The - combiner function takes two parameters: a `Module` and - a list of pairs. Each pair contains a bit which signals - that a given call succeeded, and the result of the call. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. - - Attributes - ---------- - method: Method - The product method. - """ - if combiner is None: - combiner = ([], lambda _, __: {}) - self.targets = targets - self.combiner = combiner - self.src_loc = get_src_loc(src_loc) - self.method = Method(i=targets[0].layout_in, o=combiner[0], src_loc=self.src_loc) - - def elaborate(self, platform): - m = TModule() - - @def_method(m, self.method) - def _(arg): - results: list[tuple[Value, MethodStruct]] = [] - for target in self.targets: - success = Signal() - with Transaction(src_loc=self.src_loc).body(m): - m.d.comb += success.eq(1) - results.append((success, target(m, arg))) - return self.combiner[1](m, results) - - return m - - -class Collector(Elaboratable, Unifier): - """Single result collector. - - Creates method that collects results of many methods with identical - layouts. Each call of this method will return a single result of one - of the provided methods. - - Attributes - ---------- - method: Method - Method which returns single result of provided methods. - """ - - def __init__(self, targets: list[Method], *, src_loc: int | SrcLoc = 0): - """ - Parameters - ---------- - method_list: list[Method] - List of methods from which results will be collected. - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
- """ - self.method_list = targets - layout = targets[0].layout_out - self.src_loc = get_src_loc(src_loc) - self.method = Method(o=layout, src_loc=self.src_loc) - - for method in targets: - if layout != method.layout_out: - raise Exception("Not all methods have this same layout") - - def elaborate(self, platform): - m = TModule() - - m.submodules.forwarder = forwarder = Forwarder(self.method.layout_out, src_loc=self.src_loc) - - m.submodules.connect = ManyToOneConnectTrans( - get_results=[get for get in self.method_list], put_result=forwarder.write, src_loc=self.src_loc - ) - - self.method.proxy(m, forwarder.read) - - return m - - -class CatTrans(Elaboratable): - """Concatenating transaction. - - Concatenates the results of two methods and passes the result to the - third method. - """ - - def __init__(self, src1: Method, src2: Method, dst: Method): - """ - Parameters - ---------- - src1: Method - First input method. - src2: Method - Second input method. - dst: Method - The method which receives the concatenation of the results of input - methods. - """ - self.src1 = src1 - self.src2 = src2 - self.dst = dst - - def elaborate(self, platform): - m = TModule() - - with Transaction().body(m): - sdata1 = self.src1(m) - sdata2 = self.src2(m) - ddata = Signal.like(self.dst.data_in) - self.dst(m, ddata) - - m.d.comb += ddata.eq(Cat(sdata1, sdata2)) - - return m - - -class ConnectAndMapTrans(Elaboratable): - """Connecting transaction with mapping functions. - - Behaves like `ConnectTrans`, but modifies the transferred data using - functions or `Method`s. Equivalent to a combination of `ConnectTrans` - and `MethodMap`. The mapping functions take two parameters, a `Module` - and the structure being transformed. - """ - - def __init__( - self, - method1: Method, - method2: Method, - *, - i_fun: Optional[Callable[[TModule, MethodStruct], RecordDict]] = None, - o_fun: Optional[Callable[[TModule, MethodStruct], RecordDict]] = None, - src_loc: int | SrcLoc = 0 - ): - """ - Parameters - ---------- - method1: Method - First method. - method2: Method - Second method, and the method being transformed. - i_fun: function or Method, optional - Input transformation (`method1` to `method2`). - o_fun: function or Method, optional - Output transformation (`method2` to `method1`). - src_loc: int | SrcLoc - How many stack frames deep the source location is taken from. - Alternatively, the source location to use instead of the default. 
- """ - self.method1 = method1 - self.method2 = method2 - self.i_fun = i_fun or (lambda _, x: x) - self.o_fun = o_fun or (lambda _, x: x) - self.src_loc = get_src_loc(src_loc) - - def elaborate(self, platform): - m = TModule() - - m.submodules.transformer = transformer = MethodMap( - self.method2, - i_transform=(self.method1.layout_out, self.i_fun), - o_transform=(self.method1.layout_in, self.o_fun), - src_loc=self.src_loc, - ) - m.submodules.connect = ConnectTrans(self.method1, transformer.method) - - return m diff --git a/transactron/profiler.py b/transactron/profiler.py deleted file mode 100644 index fcea59387..000000000 --- a/transactron/profiler.py +++ /dev/null @@ -1,356 +0,0 @@ -import os -from collections import defaultdict -from typing import Optional -from dataclasses import dataclass, field -from dataclasses_json import dataclass_json -from transactron.utils import SrcLoc, IdGenerator -from transactron.core import TransactionManager -from transactron.core.manager import MethodMap - - -__all__ = [ - "ProfileInfo", - "ProfileData", - "RunStat", - "RunStatNode", - "Profile", - "TransactionSamples", - "MethodSamples", - "ProfileSamples", -] - - -@dataclass_json -@dataclass -class ProfileInfo: - """Information about transactions and methods. - - In `Profile`, transactions and methods are referred to by their unique ID - numbers. - - Attributes - ---------- - name : str - The name. - src_loc : SrcLoc - Source location. - is_transaction : bool - If true, this object describes a transaction; if false, a method. - """ - - name: str - src_loc: SrcLoc - is_transaction: bool - - -@dataclass -class ProfileData: - """Information about transactions and methods from the transaction manager. - - This data is required for transaction profile generation in simulators. - Transactions and methods are referred to by their unique ID numbers. - - Attributes - ---------- - transactions_and_methods: dict[int, ProfileInfo] - Information about individual transactions and methods. - method_parents: dict[int, list[int]] - Lists the callers (transactions and methods) for each method. Key is - method ID. - transactions_by_method: dict[int, list[int]] - Lists which transactions are calling each method. Key is method ID. - transaction_conflicts: dict[int, list[int]] - List which other transactions conflict with each transaction. 
- """ - - transactions_and_methods: dict[int, ProfileInfo] - method_parents: dict[int, list[int]] - transactions_by_method: dict[int, list[int]] - transaction_conflicts: dict[int, list[int]] - - @staticmethod - def make(transaction_manager: TransactionManager): - transactions_and_methods = dict[int, ProfileInfo]() - method_parents = dict[int, list[int]]() - transactions_by_method = dict[int, list[int]]() - transaction_conflicts = dict[int, list[int]]() - - method_map = MethodMap(transaction_manager.transactions) - cgr, _ = TransactionManager._conflict_graph(method_map) - get_id = IdGenerator() - - def local_src_loc(src_loc: SrcLoc): - return (os.path.relpath(src_loc[0]), src_loc[1]) - - for transaction in method_map.transactions: - transactions_and_methods[get_id(transaction)] = ProfileInfo( - transaction.owned_name, local_src_loc(transaction.src_loc), True - ) - - for method in method_map.methods: - transactions_and_methods[get_id(method)] = ProfileInfo( - method.owned_name, local_src_loc(method.src_loc), False - ) - method_parents[get_id(method)] = [get_id(t_or_m) for t_or_m in method_map.method_parents[method]] - transactions_by_method[get_id(method)] = [ - get_id(t_or_m) for t_or_m in method_map.transactions_by_method[method] - ] - - for transaction, transactions in cgr.items(): - transaction_conflicts[get_id(transaction)] = [get_id(transaction2) for transaction2 in transactions] - - return ( - ProfileData(transactions_and_methods, method_parents, transactions_by_method, transaction_conflicts), - get_id, - ) - - -@dataclass -class RunStat: - """Collected statistics about a transaction or method. - - Attributes - ---------- - name : str - The name. - src_loc : SrcLoc - Source location. - locked : int - For methods: the number of cycles this method was locked because of - a disabled call (a call under a false condition). For transactions: - the number of cycles this transaction was ready to run, but did not - run because a conflicting transaction has run instead. - """ - - name: str - src_loc: str - locked: int = 0 - run: int = 0 - - @staticmethod - def make(info: ProfileInfo): - return RunStat(info.name, f"{info.src_loc[0]}:{info.src_loc[1]}") - - -@dataclass -class RunStatNode: - """A statistics tree. Summarizes call graph information. - - Attributes - ---------- - stat : RunStat - Statistics. - callers : dict[int, RunStatNode] - Statistics for the method callers. For transactions, this is empty. - """ - - stat: RunStat - callers: dict[int, "RunStatNode"] = field(default_factory=dict) - - @staticmethod - def make(info: ProfileInfo): - return RunStatNode(RunStat.make(info)) - - -@dataclass -class TransactionSamples: - """Runtime value of transaction control signals in a given clock cycle. - - Attributes - ---------- - request: bool - The value of the transaction's ``request`` signal. - runnable: bool - The value of the transaction's ``runnable`` signal. - grant: bool - The value of the transaction's ``grant`` signal. - """ - - request: bool - runnable: bool - grant: bool - - -@dataclass -class MethodSamples: - """Runtime value of method control signals in a given clock cycle. - - Attributes - ---------- - run: bool - The value of the method's ``run`` signal. - """ - - run: bool - - -@dataclass -class ProfileSamples: - """Runtime values of all transaction and method control signals. - - Attributes - ---------- - transactions: dict[int, TransactionSamples] - Runtime values of transaction control signals for each transaction. 
- methods: dict[int, MethodSamples] - Runtime values of method control signals for each method. - """ - - transactions: dict[int, TransactionSamples] = field(default_factory=dict) - methods: dict[int, MethodSamples] = field(default_factory=dict) - - -@dataclass_json -@dataclass -class CycleProfile: - """Profile information for a single clock cycle. - - Transactions and methods are referred to by unique IDs. - - Attributes - ---------- - locked : dict[int, int] - For each transaction which didn't run because of a conflict, the - transaction which has run instead. For each method which was used - but didn't run because of a disabled call, the caller which - used it. - running : dict[int, Optional[int]] - For each running method, its caller. Running transactions don't - have a caller (the value is `None`). - """ - - locked: dict[int, int] = field(default_factory=dict) - running: dict[int, Optional[int]] = field(default_factory=dict) - - @staticmethod - def make(samples: ProfileSamples, data: ProfileData): - cprof = CycleProfile() - - for transaction_id, transaction_samples in samples.transactions.items(): - if transaction_samples.grant: - cprof.running[transaction_id] = None - elif transaction_samples.request and transaction_samples.runnable: - for transaction2_id in data.transaction_conflicts[transaction_id]: - if samples.transactions[transaction2_id].grant: - cprof.locked[transaction_id] = transaction2_id - - running = set(cprof.running) - for method_id, method_samples in samples.methods.items(): - if method_samples.run: - running.add(method_id) - - locked_methods = set[int]() - for method_id in samples.methods.keys(): - if method_id not in running: - if any(transaction_id in running for transaction_id in data.transactions_by_method[method_id]): - locked_methods.add(method_id) - - for method_id in samples.methods.keys(): - if method_id in running: - for t_or_m_id in data.method_parents[method_id]: - if t_or_m_id in running: - cprof.running[method_id] = t_or_m_id - elif method_id in locked_methods: - caller = next( - t_or_m_id - for t_or_m_id in data.method_parents[method_id] - if t_or_m_id in running or t_or_m_id in locked_methods - ) - cprof.locked[method_id] = caller - - return cprof - - -@dataclass_json -@dataclass -class Profile: - """Transactron execution profile. - - Can be saved by the simulator, and then restored by an analysis tool. - In the profile data structure, methods and transactions are referred to - by their unique ID numbers. - - Attributes - ---------- - transactions_and_methods : dict[int, ProfileInfo] - Information about transactions and methods indexed by ID numbers. - cycles : list[CycleProfile] - Profile information for each cycle of the simulation. 
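-
-    Example
-    -------
-    An analysis sketch (the file name is hypothetical)::
-
-        profile = Profile.decode("test/__profiles__/some_test.json")
-        for node in profile.analyze_transactions(recursive=True):
-            print(node.stat.name, node.stat.run, node.stat.locked)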
- """ - - transactions_and_methods: dict[int, ProfileInfo] = field(default_factory=dict) - cycles: list[CycleProfile] = field(default_factory=list) - - def encode(self, file_name: str): - with open(file_name, "w") as fp: - fp.write(self.to_json()) # type: ignore - - @staticmethod - def decode(file_name: str) -> "Profile": - with open(file_name, "r") as fp: - return Profile.from_json(fp.read()) # type: ignore - - def analyze_transactions(self, recursive=False) -> list[RunStatNode]: - stats = {i: RunStatNode.make(info) for i, info in self.transactions_and_methods.items() if info.is_transaction} - - def rec(c: CycleProfile, node: RunStatNode, i: int): - if i in c.running: - node.stat.run += 1 - elif i in c.locked: - node.stat.locked += 1 - if recursive: - for j in called[i]: - if j not in node.callers: - node.callers[j] = RunStatNode.make(self.transactions_and_methods[j]) - rec(c, node.callers[j], j) - - for c in self.cycles: - called = defaultdict[int, set[int]](set) - - for i, j in c.running.items(): - if j is not None: - called[j].add(i) - - for i, j in c.locked.items(): - called[j].add(i) - - for i in c.running: - if i in stats: - rec(c, stats[i], i) - - for i in c.locked: - if i in stats: - stats[i].stat.locked += 1 - - return list(stats.values()) - - def analyze_methods(self, recursive=False) -> list[RunStatNode]: - stats = { - i: RunStatNode.make(info) for i, info in self.transactions_and_methods.items() if not info.is_transaction - } - - def rec(c: CycleProfile, node: RunStatNode, i: int, locking_call=False): - if i in c.running: - if not locking_call: - node.stat.run += 1 - else: - node.stat.locked += 1 - caller = c.running[i] - else: - node.stat.locked += 1 - caller = c.locked[i] - if recursive and caller is not None: - if caller not in node.callers: - node.callers[caller] = RunStatNode.make(self.transactions_and_methods[caller]) - rec(c, node.callers[caller], caller, locking_call) - - for c in self.cycles: - for i in c.running: - if i in stats: - rec(c, stats[i], i) - - for i in c.locked: - if i in stats: - rec(c, stats[i], i, locking_call=True) - - return list(stats.values()) diff --git a/transactron/testing/__init__.py b/transactron/testing/__init__.py deleted file mode 100644 index aa215228e..000000000 --- a/transactron/testing/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .input_generation import * # noqa: F401 -from .functions import * # noqa: F401 -from .infrastructure import * # noqa: F401 -from .sugar import * # noqa: F401 -from .testbenchio import * # noqa: F401 -from .profiler import * # noqa: F401 -from .logging import * # noqa: F401 -from transactron.utils import data_layout # noqa: F401 diff --git a/transactron/testing/functions.py b/transactron/testing/functions.py deleted file mode 100644 index 347a41bdc..000000000 --- a/transactron/testing/functions.py +++ /dev/null @@ -1,31 +0,0 @@ -from amaranth import * -from amaranth.lib.data import Layout, StructLayout, View -from amaranth.sim.core import Command -from typing import TypeVar, Any, Generator, TypeAlias, TYPE_CHECKING, Union -from transactron.utils._typing import RecordIntDict - - -if TYPE_CHECKING: - from amaranth.hdl._ast import Statement - from .infrastructure import CoreblocksCommand - - -T = TypeVar("T") -TestGen: TypeAlias = Generator[Union[Command, Value, "Statement", "CoreblocksCommand", None], Any, T] - - -def get_outputs(field: View) -> TestGen[RecordIntDict]: - # return dict of all signal values in a record because amaranth's simulator can't read all - # values of a View in a single yield - it can 
only read Values (Signals) - result = {} - layout = field.shape() - assert isinstance(layout, StructLayout) - for name, fld in layout: - val = field[name] - if isinstance(fld.shape, Layout): - result[name] = yield from get_outputs(View(fld.shape, val)) - elif isinstance(val, Value): - result[name] = yield val - else: - raise ValueError - return result diff --git a/transactron/testing/gtkw_extension.py b/transactron/testing/gtkw_extension.py deleted file mode 100644 index db407ac9d..000000000 --- a/transactron/testing/gtkw_extension.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Iterable, Mapping -from contextlib import contextmanager -from amaranth.lib.data import View -from amaranth.sim.pysim import _VCDWriter -from amaranth.sim import Tick -from amaranth import * -from transactron.utils import flatten_signals - - -class _VCDWriterExt(_VCDWriter): - def __init__(self, design, *, vcd_file, gtkw_file, traces): - super().__init__(design=design, vcd_file=vcd_file, gtkw_file=gtkw_file, traces=list(flatten_signals(traces))) - self._tree_traces = traces - - def close(self, timestamp): - def save_signal(value: Value): - for signal in value._rhs_signals(): # type: ignore - if signal in self.gtkw_signal_names: - for name in self.gtkw_signal_names[signal]: - self.gtkw_save.trace(name) - - def gtkw_traces(traces): - if isinstance(traces, Mapping): - for k, v in traces.items(): - with self.gtkw_save.group(k): - gtkw_traces(v) - elif isinstance(traces, Iterable): - for v in traces: - gtkw_traces(v) - elif isinstance(traces, Record): - if len(traces.fields) > 1: - with self.gtkw_save.group(traces.name): - for v in traces.fields.values(): - gtkw_traces(v) - elif len(traces.fields) == 1: # to make gtkwave view less verbose - gtkw_traces(next(iter(traces.fields.values()))) - elif isinstance(traces, View): - v = Value.cast(traces) - with self.gtkw_save.group(v.name if isinstance(v, Signal) else ""): - save_signal(v) - elif isinstance(traces, Value): - save_signal(traces) - - if self.vcd_writer is not None: - self.vcd_writer.close(timestamp) - - if self.gtkw_save is not None: - self.gtkw_save.dumpfile(self.vcd_file.name) - self.gtkw_save.dumpfile_size(self.vcd_file.tell()) - self.gtkw_save.zoom_markers(-21) - - self.gtkw_save.treeopen("top") - gtkw_traces(self._tree_traces) - - if self.vcd_file is not None: - self.vcd_file.close() - if self.gtkw_file is not None: - self.gtkw_file.close() - - -@contextmanager -def write_vcd_ext(engine, vcd_file, gtkw_file, traces): - vcd_writer = _VCDWriterExt(engine._design, vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces) - try: - engine._vcd_writers.append(vcd_writer) - yield Tick() - finally: - vcd_writer.close(engine.now) - engine._vcd_writers.remove(vcd_writer) diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py deleted file mode 100644 index 861428fc3..000000000 --- a/transactron/testing/infrastructure.py +++ /dev/null @@ -1,365 +0,0 @@ -import sys -import pytest -import logging -import os -import random -import functools -import warnings -from contextlib import contextmanager, nullcontext -from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias, Optional -from abc import ABC -from amaranth import * -from amaranth.sim import * - -from transactron.utils.dependencies import DependencyContext, DependencyManager -from .testbenchio import TestbenchIO -from .profiler import profiler_process, Profile -from .functions import TestGen -from .logging import make_logging_process, 
parse_logging_level, _LogFormatter -from .gtkw_extension import write_vcd_ext -from transactron import Method -from transactron.lib import AdapterTrans -from transactron.core.keys import TransactionManagerKey -from transactron.core import TransactionModule -from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals - - -T = TypeVar("T") -_T_nested_collection: TypeAlias = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] - - -def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]: - if isinstance(cont, (list, dict)): - if isinstance(cont, dict): - cont = cont.values() - return all([guard_nested_collection(elem, t) for elem in cont]) - elif isinstance(cont, t): - return True - else: - return False - - -_T_HasElaborate = TypeVar("_T_HasElaborate", bound=HasElaborate) - - -class SimpleTestCircuit(Elaboratable, Generic[_T_HasElaborate]): - def __init__(self, dut: _T_HasElaborate): - self._dut = dut - self._io: dict[str, _T_nested_collection[TestbenchIO]] = {} - - def __getattr__(self, name: str) -> Any: - try: - return self._io[name] - except KeyError: - raise AttributeError(f"No mock for '{name}'") - - def elaborate(self, platform): - def transform_methods_to_testbenchios( - container: _T_nested_collection[Method], - ) -> tuple[_T_nested_collection["TestbenchIO"], Union[ModuleConnector, "TestbenchIO"]]: - if isinstance(container, list): - tb_list = [] - mc_list = [] - for elem in container: - tb, mc = transform_methods_to_testbenchios(elem) - tb_list.append(tb) - mc_list.append(mc) - return tb_list, ModuleConnector(*mc_list) - elif isinstance(container, dict): - tb_dict = {} - mc_dict = {} - for name, elem in container.items(): - tb, mc = transform_methods_to_testbenchios(elem) - tb_dict[name] = tb - mc_dict[name] = mc - return tb_dict, ModuleConnector(*mc_dict) - else: - tb = TestbenchIO(AdapterTrans(container)) - return tb, tb - - m = Module() - - m.submodules.dut = self._dut - - for name, attr in vars(self._dut).items(): - if guard_nested_collection(attr, Method) and attr: - tb_cont, mc = transform_methods_to_testbenchios(attr) - self._io[name] = tb_cont - m.submodules[name] = mc - - return m - - def debug_signals(self): - sigs = {"_dut": auto_debug_signals(self._dut)} - for name, io in self._io.items(): - sigs[name] = auto_debug_signals(io) - return sigs - - -class _TestModule(Elaboratable): - def __init__(self, tested_module: HasElaborate, add_transaction_module: bool): - self.tested_module = ( - TransactionModule(tested_module, dependency_manager=DependencyContext.get()) - if add_transaction_module - else tested_module - ) - self.add_transaction_module = add_transaction_module - - def elaborate(self, platform) -> HasElaborate: - m = Module() - - # so that Amaranth allows us to use add_clock - _dummy = Signal() - m.d.sync += _dummy.eq(1) - - m.submodules.tested_module = self.tested_module - - m.domains.sync_neg = ClockDomain(clk_edge="neg", local=True) - - return m - - -class CoreblocksCommand(ABC): - pass - - -class Now(CoreblocksCommand): - pass - - -class SyncProcessWrapper: - def __init__(self, f): - self.org_process = f - self.current_cycle = 0 - - def _wrapping_function(self): - response = None - org_coroutine = self.org_process() - try: - while True: - # call orginal test process and catch data yielded by it in `command` variable - command = org_coroutine.send(response) - # If process wait for new cycle - if command is None or isinstance(command, Tick): - command = command or Tick() - # 
TODO: use of other domains can mess up the counter! - if command.domain == "sync": - self.current_cycle += 1 - # forward to amaranth - yield command - elif isinstance(command, Now): - response = self.current_cycle - # Pass everything else to amaranth simulator without modifications - else: - response = yield command - except StopIteration: - pass - - -class PysimSimulator(Simulator): - def __init__( - self, - module: HasElaborate, - max_cycles: float = 10e4, - add_transaction_module=True, - traces_file=None, - clk_period=1e-6, - ): - test_module = _TestModule(module, add_transaction_module) - self.tested_module = tested_module = test_module.tested_module - super().__init__(test_module) - - self.add_clock(clk_period) - self.add_clock(clk_period, domain="sync_neg") - - if isinstance(tested_module, HasDebugSignals): - extra_signals = tested_module.debug_signals - else: - extra_signals = functools.partial(auto_debug_signals, tested_module) - - if traces_file: - traces_dir = "test/__traces__" - os.makedirs(traces_dir, exist_ok=True) - # Signal handling is hacky and accesses Simulator internals. - # TODO: try to merge with Amaranth. - if isinstance(extra_signals, Callable): - extra_signals = extra_signals() - clocks = [d.clk for d in cast(Any, self)._design.fragment.domains.values()] - - self.ctx = write_vcd_ext( - cast(Any, self)._engine, - f"{traces_dir}/{traces_file}.vcd", - f"{traces_dir}/{traces_file}.gtkw", - traces=[clocks, extra_signals], - ) - else: - self.ctx = nullcontext() - - self.deadline = clk_period * max_cycles - - def add_process(self, f: Callable[[], TestGen]): - f_wrapped = SyncProcessWrapper(f) - super().add_process(f_wrapped._wrapping_function) - - def run(self) -> bool: - with self.ctx: - self.run_until(self.deadline) - - return not self.advance() - - -class TestCaseWithSimulator: - dependency_manager: DependencyManager - - @contextmanager - def configure_dependency_context(self): - self.dependency_manager = DependencyManager() - with DependencyContext(self.dependency_manager): - yield Tick() - - def add_class_mocks(self, sim: PysimSimulator) -> None: - for key in dir(self): - val = getattr(self, key) - if hasattr(val, "_transactron_testing_process"): - sim.add_process(val) - - def add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: - for key, val in frame_locals.items(): - if hasattr(val, "_transactron_testing_process"): - sim.add_process(val) - - def add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None: - self.add_class_mocks(sim) - self.add_local_mocks(sim, frame_locals) - - def configure_traces(self): - traces_file = None - if "__TRANSACTRON_DUMP_TRACES" in os.environ: - traces_file = self._transactron_current_output_file_name - self._transactron_infrastructure_traces_file = traces_file - - @contextmanager - def configure_profiles(self): - profile = None - if "__TRANSACTRON_PROFILE" in os.environ: - - def f(): - nonlocal profile - try: - transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey()) - profile = Profile() - return profiler_process(transaction_manager, profile) - except KeyError: - pass - return None - - self._transactron_sim_processes_to_add.append(f) - - yield - - if profile is not None: - profile_dir = "test/__profiles__" - profile_file = self._transactron_current_output_file_name - os.makedirs(profile_dir, exist_ok=True) - profile.encode(f"{profile_dir}/{profile_file}.json") - - @contextmanager - def configure_logging(self): - def on_error(): - assert False, "Simulation finished due to an 
error" - - log_level = parse_logging_level(os.environ["__TRANSACTRON_LOG_LEVEL"]) - log_filter = os.environ["__TRANSACTRON_LOG_FILTER"] - self._transactron_sim_processes_to_add.append(lambda: make_logging_process(log_level, log_filter, on_error)) - - ch = logging.StreamHandler() - formatter = _LogFormatter() - ch.setFormatter(formatter) - - root_logger = logging.getLogger() - handlers_before = root_logger.handlers.copy() - root_logger.handlers.append(ch) - yield - root_logger.handlers = handlers_before - - @contextmanager - def reinitialize_fixtures(self): - # File name to be used in the current test run (either standard or hypothesis iteration) - # for standard tests it will always have the suffix "_0". For hypothesis tests, it will be suffixed - # with the current hypothesis iteration number, so that each hypothesis run is saved to a - # the different file. - self._transactron_current_output_file_name = ( - self._transactron_base_output_file_name + "_" + str(self._transactron_hypothesis_iter_counter) - ) - self._transactron_sim_processes_to_add: list[Callable[[], Optional[Callable]]] = [] - with self.configure_dependency_context(): - self.configure_traces() - with self.configure_profiles(): - with self.configure_logging(): - yield - self._transactron_hypothesis_iter_counter += 1 - - @pytest.fixture(autouse=True) - def fixture_initialize_testing_env(self, request): - # Hypothesis creates a single instance of a test class, which is later reused multiple times. - # This means that pytest fixtures are only run once. We can take advantage of this behaviour and - # initialise hypothesis related variables. - - # The counter for distinguishing between successive hypothesis iterations, it is incremented - # by `reinitialize_fixtures` which should be started at the beginning of each hypothesis run - self._transactron_hypothesis_iter_counter = 0 - # Base name which will be used later to create file names for particular outputs - self._transactron_base_output_file_name = ".".join(request.node.nodeid.split("/")) - with self.reinitialize_fixtures(): - yield - - @contextmanager - def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True): - clk_period = 1e-6 - sim = PysimSimulator( - module, - max_cycles=max_cycles, - add_transaction_module=add_transaction_module, - traces_file=self._transactron_infrastructure_traces_file, - clk_period=clk_period, - ) - self.add_all_mocks(sim, sys._getframe(2).f_locals) - - yield sim - - for f in self._transactron_sim_processes_to_add: - ret = f() - if ret is not None: - sim.add_process(ret) - - with warnings.catch_warnings(): - # TODO: figure out testing without settles! - warnings.filterwarnings("ignore", r"The `Settle` command is deprecated per RFC 27\.") - - res = sim.run() - assert res, "Simulation time limit exceeded" - - def tick(self, cycle_cnt: int = 1): - """ - Yields for the given number of cycles. - """ - - for _ in range(cycle_cnt): - yield Tick() - - def random_wait(self, max_cycle_cnt: int, *, min_cycle_cnt: int = 0): - """ - Wait for a random amount of cycles in range [min_cycle_cnt, max_cycle_cnt] - """ - yield from self.tick(random.randrange(min_cycle_cnt, max_cycle_cnt + 1)) - - def random_wait_geom(self, prob: float = 0.5): - """ - Wait till the first success, where there is `prob` probability for success in each cycle. 
- """ - while random.random() > prob: - yield Tick() - - def multi_settle(self, settle_count: int = 1): - for _ in range(settle_count): - yield Settle() diff --git a/transactron/testing/input_generation.py b/transactron/testing/input_generation.py deleted file mode 100644 index 909da7a43..000000000 --- a/transactron/testing/input_generation.py +++ /dev/null @@ -1,97 +0,0 @@ -from amaranth import * -from amaranth.lib.data import StructLayout -from typing import TypeVar -import hypothesis.strategies as st -from hypothesis.strategies import composite, DrawFn, integers, SearchStrategy -from transactron.utils import MethodLayout, RecordIntDict - - -class OpNOP: - def __repr__(self): - return "OpNOP()" - - -T = TypeVar("T") - - -@composite -def generate_shrinkable_list(draw: DrawFn, length: int, generator: SearchStrategy[T]) -> list[T]: - """ - Trick based on https://github.com/HypothesisWorks/hypothesis/blob/ - 6867da71beae0e4ed004b54b92ef7c74d0722815/hypothesis-python/src/hypothesis/stateful.py#L143 - """ - hp_data = draw(st.data()) - lst = [] - if length == 0: - return lst - i = 0 - force_val = None - while True: - b = hp_data.conjecture_data.draw_boolean(p=2**-16, forced=force_val) - if b: - break - lst.append(draw(generator)) - i += 1 - if i == length: - force_val = True - return lst - - -@composite -def generate_based_on_layout(draw: DrawFn, layout: MethodLayout) -> RecordIntDict: - if isinstance(layout, StructLayout): - raise NotImplementedError("StructLayout is not supported in automatic value generation.") - d = {} - for name, sublayout in layout: - if isinstance(sublayout, list): - elem = draw(generate_based_on_layout(sublayout)) - elif isinstance(sublayout, int): - elem = draw(integers(min_value=0, max_value=sublayout)) - elif isinstance(sublayout, range): - elem = draw(integers(min_value=sublayout.start, max_value=sublayout.stop - 1)) - elif isinstance(sublayout, Shape): - if sublayout.signed: - min_value = -(2 ** (sublayout.width - 1)) - max_value = 2 ** (sublayout.width - 1) - 1 - else: - min_value = 0 - max_value = 2**sublayout.width - elem = draw(integers(min_value=min_value, max_value=max_value)) - else: - # Currently type[Enum] and ShapeCastable - raise NotImplementedError("Passed LayoutList with syntax yet unsuported in automatic value generation.") - d[name] = elem - return d - - -def insert_nops(draw: DrawFn, max_nops: int, lst: list): - nops_nr = draw(integers(min_value=0, max_value=max_nops)) - for i in range(nops_nr): - lst.append(OpNOP()) - return lst - - -@composite -def generate_nops_in_list(draw: DrawFn, max_nops: int, generate_list: SearchStrategy[list[T]]) -> list[T | OpNOP]: - lst = draw(generate_list) - out_lst = [] - out_lst = insert_nops(draw, max_nops, out_lst) - for i in lst: - out_lst.append(i) - out_lst = insert_nops(draw, max_nops, out_lst) - return out_lst - - -@composite -def generate_method_input(draw: DrawFn, args: list[tuple[str, MethodLayout]]) -> dict[str, RecordIntDict]: - out = [] - for name, layout in args: - out.append((name, draw(generate_based_on_layout(layout)))) - return dict(out) - - -@composite -def generate_process_input( - draw: DrawFn, elem_count: int, max_nops: int, layouts: list[tuple[str, MethodLayout]] -) -> list[dict[str, RecordIntDict] | OpNOP]: - return draw(generate_nops_in_list(max_nops, generate_shrinkable_list(elem_count, generate_method_input(layouts)))) diff --git a/transactron/testing/logging.py b/transactron/testing/logging.py deleted file mode 100644 index 7c8edf1dc..000000000 --- a/transactron/testing/logging.py +++ 
/dev/null @@ -1,105 +0,0 @@ -from collections.abc import Callable -from typing import Any -import logging - -from amaranth.sim import Passive, Tick -from transactron.lib import logging as tlog - - -__all__ = ["make_logging_process", "parse_logging_level"] - - -def parse_logging_level(str: str) -> tlog.LogLevel: - """Parse the log level from a string. - - The level can be either a non-negative integer or a string representation - of one of the predefined levels. - - Raises an exception if the level cannot be parsed. - """ - str = str.upper() - names_mapping = logging.getLevelNamesMapping() - if str in names_mapping: - return names_mapping[str] - - # try convert to int - try: - return int(str) - except ValueError: - pass - - raise ValueError("Log level must be either {error, warn, info, debug} or a non-negative integer.") - - -_sim_cycle = 0 - - -class _LogFormatter(logging.Formatter): - """ - Log formatter to provide colors and to inject simulator times into - the log messages. Adapted from https://stackoverflow.com/a/56944256/3638629 - """ - - magenta = "\033[0;35m" - grey = "\033[0;34m" - blue = "\033[0;34m" - yellow = "\033[0;33m" - red = "\033[0;31m" - reset = "\033[0m" - - loglevel2colour = { - logging.DEBUG: grey + "{}" + reset, - logging.INFO: magenta + "{}" + reset, - logging.WARNING: yellow + "{}" + reset, - logging.ERROR: red + "{}" + reset, - } - - def format(self, record: logging.LogRecord): - level_name = self.loglevel2colour[record.levelno].format(record.levelname) - return f"{_sim_cycle} {level_name} {record.name} {record.getMessage()}" - - -def make_logging_process(level: tlog.LogLevel, namespace_regexp: str, on_error: Callable[[], Any]): - combined_trigger = tlog.get_trigger_bit(level, namespace_regexp) - records = tlog.get_log_records(level, namespace_regexp) - - root_logger = logging.getLogger() - - def handle_logs(): - if not (yield combined_trigger): - return - - for record in records: - if not (yield record.trigger): - continue - - values: list[int] = [] - for field in record.fields: - values.append((yield field)) - - formatted_msg = record.format(*values) - - logger = root_logger.getChild(record.logger_name) - logger.log( - record.level, - "[%s:%d] %s", - record.location[0], - record.location[1], - formatted_msg, - ) - - if record.level >= logging.ERROR: - on_error() - - def log_process(): - global _sim_cycle - _sim_cycle = 0 - - yield Passive() - while True: - yield Tick("sync_neg") - yield from handle_logs() - yield Tick() - _sim_cycle += 1 - - return log_process diff --git a/transactron/testing/profiler.py b/transactron/testing/profiler.py deleted file mode 100644 index 795c7f293..000000000 --- a/transactron/testing/profiler.py +++ /dev/null @@ -1,37 +0,0 @@ -from amaranth.sim import * -from transactron.core import TransactionManager -from transactron.core.manager import MethodMap -from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileData, ProfileSamples, TransactionSamples -from .functions import TestGen - -__all__ = ["profiler_process"] - - -def profiler_process(transaction_manager: TransactionManager, profile: Profile): - def process() -> TestGen: - profile_data, get_id = ProfileData.make(transaction_manager) - method_map = MethodMap(transaction_manager.transactions) - profile.transactions_and_methods = profile_data.transactions_and_methods - - yield Passive() - while True: - yield Tick("sync_neg") - - samples = ProfileSamples() - - for transaction in method_map.transactions: - samples.transactions[get_id(transaction)] = 
TransactionSamples( - bool((yield transaction.request)), - bool((yield transaction.runnable)), - bool((yield transaction.grant)), - ) - - for method in method_map.methods: - samples.methods[get_id(method)] = MethodSamples(bool((yield method.run))) - - cprof = CycleProfile.make(samples, profile_data) - profile.cycles.append(cprof) - - yield Tick() - - return process diff --git a/transactron/testing/sugar.py b/transactron/testing/sugar.py deleted file mode 100644 index de1dc5e21..000000000 --- a/transactron/testing/sugar.py +++ /dev/null @@ -1,82 +0,0 @@ -import functools -from typing import Callable, Any, Optional -from .testbenchio import TestbenchIO, TestGen -from transactron.utils._typing import RecordIntDict - - -def def_method_mock( - tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs -) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]: - """ - Decorator function to create method mock handlers. It should be applied on - a function which describes functionality which we want to invoke on method call. - Such function will be wrapped by `method_handle_loop` and called on each - method invocation. - - Function `f` should take only one argument `arg` - data used in function - invocation - and should return data to be sent as response to the method call. - - Function `f` can also be a method and take two arguments `self` and `arg`, - the data to be passed on to invoke a method. It should return data to be sent - as response to the method call. - - Instead of the `arg` argument, the data can be split into keyword arguments. - - Make sure to defer accessing state, since decorators are evaluated eagerly - during function declaration. - - Parameters - ---------- - tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] - Function to get the TestbenchIO providing appropriate `method_handle_loop`. - **kwargs - Arguments passed to `method_handle_loop`. 
- - Example - ------- - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k]) - def process(arg): - return {"data": arg["data"] + k} - return process - ``` - or equivalently - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k], settle=1, enable=False) - def process(data): - return {"data": data + k} - return process - ``` - or for class methods - ``` - @def_method_mock(lambda self: self.target[k], settle=1, enable=False) - def process(self, data): - return {"data": data + k} - ``` - """ - - def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], TestGen[None]]: - @functools.wraps(func) - def mock(func_self=None, /) -> TestGen[None]: - f = func - getter: Any = tb_getter - kw = kwargs - if func_self is not None: - getter = getter.__get__(func_self) - f = f.__get__(func_self) - kw = {} - for k, v in kwargs.items(): - bind = getattr(v, "__get__", None) - kw[k] = bind(func_self) if bind else v - tb = getter() - assert isinstance(tb, TestbenchIO) - yield from tb.method_handle_loop(f, extra_settle_count=sched_prio, **kw) - - mock._transactron_testing_process = 1 # type: ignore - return mock - - return decorator diff --git a/transactron/testing/testbenchio.py b/transactron/testing/testbenchio.py deleted file mode 100644 index 7611a1e6f..000000000 --- a/transactron/testing/testbenchio.py +++ /dev/null @@ -1,148 +0,0 @@ -from amaranth import * -from amaranth.sim import Settle, Passive, Tick -from typing import Optional, Callable -from transactron.lib import AdapterBase -from transactron.lib.adapters import Adapter -from transactron.utils import ValueLike, SignalBundle, mock_def_helper, assign -from transactron.utils._typing import RecordIntDictRet, RecordValueDict, RecordIntDict -from .functions import get_outputs, TestGen - - -class TestbenchIO(Elaboratable): - def __init__(self, adapter: AdapterBase): - self.adapter = adapter - - def elaborate(self, platform): - m = Module() - m.submodules += self.adapter - return m - - # Low-level operations - - def set_enable(self, en) -> TestGen[None]: - yield self.adapter.en.eq(1 if en else 0) - - def enable(self) -> TestGen[None]: - yield from self.set_enable(True) - - def disable(self) -> TestGen[None]: - yield from self.set_enable(False) - - def done(self) -> TestGen[int]: - return (yield self.adapter.done) - - def wait_until_done(self) -> TestGen[None]: - while (yield self.adapter.done) != 1: - yield Tick() - - def set_inputs(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from assign(self.adapter.data_in, data) - - def get_outputs(self) -> TestGen[RecordIntDictRet]: - return (yield from get_outputs(self.adapter.data_out)) - - # Operations for AdapterTrans - - def call_init(self, data: RecordValueDict = {}, /, **kwdata: ValueLike | RecordValueDict) -> TestGen[None]: - if data and kwdata: - raise TypeError("call_init() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.enable() - yield from self.set_inputs(data) - - def call_result(self) -> TestGen[Optional[RecordIntDictRet]]: - if (yield from self.done()): - return (yield from self.get_outputs()) - return None - - def call_do(self) -> TestGen[RecordIntDict]: - while (outputs := (yield from self.call_result())) is None: - yield Tick() - yield from self.disable() - return outputs - - def call_try( - self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict - ) -> TestGen[Optional[RecordIntDictRet]]: - if data and kwdata: - raise 
TypeError("call_try() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.call_init(data) - yield Tick() - outputs = yield from self.call_result() - yield from self.disable() - return outputs - - def call(self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict) -> TestGen[RecordIntDictRet]: - if data and kwdata: - raise TypeError("call() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.call_init(data) - yield Tick() - return (yield from self.call_do()) - - # Operations for Adapter - - def method_argument(self) -> TestGen[Optional[RecordIntDictRet]]: - return (yield from self.call_result()) - - def method_return(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from self.set_inputs(data) - - def method_handle( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - validate_arguments: Optional[Callable[..., bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - enable = enable or (lambda: True) - yield from self.set_enable(enable()) - - def handle_validate_arguments(): - if validate_arguments is not None: - assert isinstance(self.adapter, Adapter) - for a, r in self.adapter.validators: - ret_out = mock_def_helper(self, validate_arguments, (yield from get_outputs(a))) - yield r.eq(ret_out) - for _ in range(extra_settle_count + 1): - yield Settle() - - # One extra Settle() required to propagate enable signal. - for _ in range(extra_settle_count + 1): - yield Settle() - yield from handle_validate_arguments() - while (arg := (yield from self.method_argument())) is None: - yield Tick() - - yield from self.set_enable(enable()) - for _ in range(extra_settle_count + 1): - yield Settle() - yield from handle_validate_arguments() - - ret_out = mock_def_helper(self, function, arg) - yield from self.method_return(ret_out or {}) - yield Tick() - - def method_handle_loop( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - validate_arguments: Optional[Callable[..., bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - yield Passive() - while True: - yield from self.method_handle( - function, enable=enable, validate_arguments=validate_arguments, extra_settle_count=extra_settle_count - ) - - # Debug signals - - def debug_signals(self) -> SignalBundle: - return self.adapter.debug_signals() diff --git a/transactron/tracing.py b/transactron/tracing.py deleted file mode 100644 index 6f9c709f1..000000000 --- a/transactron/tracing.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -Utilities for extracting dependencies from Amaranth. -""" - -import warnings - -from amaranth.hdl import Elaboratable, Fragment, Instance -from amaranth.hdl._xfrm import FragmentTransformer -from amaranth.hdl import _dsl, _ir, _mem, _xfrm -from amaranth.lib import memory # type: ignore -from amaranth_types import SrcLoc -from transactron.utils import HasElaborate -from . import core - - -# generic tuple because of aggressive monkey-patching -modules_with_fragment: tuple = core, _ir, _dsl, _mem, _xfrm -# List of Fragment subclasses which should be patched to inherit from TracingFragment. -# The first element of the tuple is a subclass name to patch, and the second element -# of the tuple is tuple with modules in which the patched subclass should be installed. 
-fragment_subclasses_to_patch = [("MemoryInstance", (memory, _mem, _xfrm))]
-
-DIAGNOSTICS = False
-orig_on_fragment = FragmentTransformer.on_fragment
-
-
-class TracingEnabler:
-    def __enter__(self):
-        self.orig_fragment_get = Fragment.get
-        self.orig_on_fragment = FragmentTransformer.on_fragment
-        self.orig_fragment_class = _ir.Fragment
-        self.orig_instance_class = _ir.Instance
-        self.orig_patched_fragment_subclasses = []
-        Fragment.get = TracingFragment.get
-        FragmentTransformer.on_fragment = TracingFragmentTransformer.on_fragment
-        for mod in modules_with_fragment:
-            mod.Fragment = TracingFragment
-            mod.Instance = TracingInstance
-        for class_name, modules in fragment_subclasses_to_patch:
-            orig_fragment_subclass = getattr(modules[0], class_name)
-            # `type` is used to declare a new class dynamically. `orig_fragment_subclass` is passed as the first
-            # base class so that `super()` keeps working. Calls to `super` without arguments are syntax sugar,
-            # expanded at compile/interpretation (not execution!) time to `super(OriginalClass, self)`,
-            # so at execution time they are hardcoded to look for the original class
-            # (see: https://docs.python.org/3/library/functions.html#super).
-            # This means that OriginalClass has to be in the `__mro__` of the newly created class, otherwise a
-            # TypeError is raised (see: https://stackoverflow.com/a/40819403). Adding OriginalClass to the
-            # bases of the patched class fixes the TypeError. Everything works correctly because `super`
-            # starts searching the `__mro__` from the class right after its first argument; in our case the
-            # first class checked will be `TracingFragment`, as we want.
-            newclass = type(
-                class_name,
-                (
-                    orig_fragment_subclass,
-                    TracingFragment,
-                ),
-                dict(orig_fragment_subclass.__dict__),
-            )
-            for mod in modules:
-                setattr(mod, class_name, newclass)
-            self.orig_patched_fragment_subclasses.append((class_name, orig_fragment_subclass, modules))
-
-    def __exit__(self, tp, val, tb):
-        Fragment.get = self.orig_fragment_get
-        FragmentTransformer.on_fragment = self.orig_on_fragment
-        for mod in modules_with_fragment:
-            mod.Fragment = self.orig_fragment_class
-            mod.Instance = self.orig_instance_class
-        for class_name, orig_fragment_subclass, modules in self.orig_patched_fragment_subclasses:
-            for mod in modules:
-                setattr(mod, class_name, orig_fragment_subclass)
-
-
-class TracingFragmentTransformer(FragmentTransformer):
-    def on_fragment(self: FragmentTransformer, fragment):
-        ret = orig_on_fragment(self, fragment)
-        ret._tracing_original = fragment
-        fragment._elaborated = ret
-        return ret
-
-
-class TracingFragment(Fragment):
-    _tracing_original: Elaboratable
-    subfragments: list[tuple[Elaboratable, str, SrcLoc]]
-
-    if DIAGNOSTICS:
-
-        def __init__(self, *args, **kwargs):
-            import sys
-            import traceback
-
-            self.created = traceback.format_stack(sys._getframe(1))
-            super().__init__(*args, **kwargs)
-
-        def __del__(self):
-            if not hasattr(self, "_tracing_original"):
-                print("Missing tracing hook:")
-                for line in self.created:
-                    print(line, end="")
-
-    @staticmethod
-    def get(obj: HasElaborate, platform) -> "TracingFragment":
-        """
-        This function's code is based on Amaranth, which originally loses all information.
-        It was too difficult to hook into, so this has to be a near-exact copy.
-
-        Relevant copyrights apply.
- """ - with TracingEnabler(): - code = None - old_obj = None - while True: - if isinstance(obj, TracingFragment): - return obj - elif isinstance(obj, Fragment): - raise NotImplementedError(f"Monkey-patching missed some Fragment in {old_obj}.elaborate()?") - # This is literally taken from Amaranth {{ - elif isinstance(obj, Elaboratable): - code = obj.elaborate.__code__ - obj._MustUse__used = True # type: ignore - new_obj = obj.elaborate(platform) - elif hasattr(obj, "elaborate"): - warnings.warn( - message="Class {!r} is an elaboratable that does not explicitly inherit from " - "Elaboratable; doing so would improve diagnostics".format(type(obj)), - category=RuntimeWarning, - stacklevel=2, - ) - code = obj.elaborate.__code__ - new_obj = obj.elaborate(platform) - else: - raise AttributeError("Object {!r} cannot be elaborated".format(obj)) - if new_obj is obj: - raise RecursionError("Object {!r} elaborates to itself".format(obj)) - if new_obj is None and code is not None: - warnings.warn_explicit( - message=".elaborate() returned None; missing return statement?", - category=UserWarning, - filename=code.co_filename, - lineno=code.co_firstlineno, - ) - # }} (taken from Amaranth) - new_obj._tracing_original = obj # type: ignore - obj._elaborated = new_obj # type: ignore - - old_obj = obj - obj = new_obj - - def prepare(self, *args, **kwargs) -> "TracingFragment": - with TracingEnabler(): - ret = super().prepare(*args, **kwargs) - ret._tracing_original = self - self._elaborated = ret - return ret - - -class TracingInstance(Instance, TracingFragment): - _tracing_original: Elaboratable - get = TracingFragment.get diff --git a/transactron/utils/__init__.py b/transactron/utils/__init__.py deleted file mode 100644 index ebf845b7d..000000000 --- a/transactron/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .data_repr import * # noqa: F401 -from ._typing import * # noqa: F401 -from .debug_signals import * # noqa: F401 -from .assign import * # noqa: F401 -from .amaranth_ext import * # noqa: F401 -from .transactron_helpers import * # noqa: F401 -from .dependencies import * # noqa: F401 -from .depcache import * # noqa: F401 -from .idgen import * # noqa: F401 diff --git a/transactron/utils/_typing.py b/transactron/utils/_typing.py deleted file mode 100644 index 1a264527b..000000000 --- a/transactron/utils/_typing.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import ( - Callable, - Concatenate, - ParamSpec, - Protocol, - TypeAlias, - TypeVar, - cast, - runtime_checkable, - Union, - Any, -) -from collections.abc import Iterable, Mapping -from amaranth import * -from amaranth.lib.data import StructLayout, View -from amaranth_types import * -from amaranth_types import _ModuleBuilderDomainsLike - -__all__ = [ - "FragmentLike", - "ValueLike", - "ShapeLike", - "StatementLike", - "SwitchKey", - "SrcLoc", - "MethodLayout", - "MethodStruct", - "SignalBundle", - "LayoutListField", - "LayoutList", - "LayoutIterable", - "RecordIntDict", - "RecordIntDictRet", - "RecordValueDict", - "RecordDict", - "ROGraph", - "Graph", - "GraphCC", - "_ModuleBuilderDomainsLike", - "ModuleLike", - "HasElaborate", - "HasDebugSignals", -] - -# Internal Coreblocks types -SignalBundle: TypeAlias = Signal | Record | View | Iterable["SignalBundle"] | Mapping[str, "SignalBundle"] -LayoutListField: TypeAlias = tuple[str, "ShapeLike | LayoutList"] -LayoutList: TypeAlias = list[LayoutListField] -LayoutIterable: TypeAlias = Iterable[LayoutListField] -MethodLayout: TypeAlias = StructLayout | LayoutIterable -MethodStruct: TypeAlias = 
"View[StructLayout]" - -RecordIntDict: TypeAlias = Mapping[str, Union[int, "RecordIntDict"]] -RecordIntDictRet: TypeAlias = Mapping[str, Any] # full typing hard to work with -RecordValueDict: TypeAlias = Mapping[str, Union[ValueLike, "RecordValueDict"]] -RecordDict: TypeAlias = ValueLike | Mapping[str, "RecordDict"] - -T = TypeVar("T") -U = TypeVar("U") -P = ParamSpec("P") - -ROGraph: TypeAlias = Mapping[T, Iterable[T]] -Graph: TypeAlias = dict[T, set[T]] -GraphCC: TypeAlias = set[T] - - -@runtime_checkable -class HasDebugSignals(Protocol): - def debug_signals(self) -> SignalBundle: ... - - -def type_self_kwargs_as(as_func: Callable[Concatenate[Any, P], Any]): - """ - Decorator used to annotate `**kwargs` type to be the same as named arguments from `as_func` method. - - Works only with methods with (self, **kwargs) signature. `self` parameter is also required in `as_func`. - """ - - def return_func(func: Callable[Concatenate[Any, ...], T]) -> Callable[Concatenate[Any, P], T]: - return cast(Callable[Concatenate[Any, P], T], func) - - return return_func diff --git a/transactron/utils/amaranth_ext/__init__.py b/transactron/utils/amaranth_ext/__init__.py deleted file mode 100644 index 2b8533b12..000000000 --- a/transactron/utils/amaranth_ext/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .functions import * # noqa: F401 -from .elaboratables import * # noqa: F401 diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py deleted file mode 100644 index ed6b57122..000000000 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ /dev/null @@ -1,532 +0,0 @@ -import itertools -from contextlib import contextmanager -from typing import Literal, Optional, overload -from collections.abc import Iterable -from amaranth import * -from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike - -__all__ = [ - "OneHotSwitchDynamic", - "OneHotSwitch", - "ModuleConnector", - "Scheduler", - "RoundRobin", - "MultiPriorityEncoder", - "RingMultiPriorityEncoder", -] - - -@contextmanager -def OneHotSwitch(m: ModuleLike, test: Value): - """One-hot switch. - - This function allows one-hot matching in the style similar to the standard - Amaranth `Switch`. This allows to get the performance benefit of using - the one-hot representation. - - Example:: - - with OneHotSwitch(m, sig) as OneHotCase: - with OneHotCase(0b01): - ... - with OneHotCase(0b10): - ... - # optional default case - with OneHotCase(): - ... - - Parameters - ---------- - m : Module - The module for which the matching is defined. - test : Signal - The signal being tested. - """ - - @contextmanager - def case(n: Optional[int] = None): - if n is None: - with m.Default(): - yield - else: - # find the index of the least significant bit set - i = (n & -n).bit_length() - 1 - if n - (1 << i) != 0: - raise ValueError("%d not in one-hot representation" % n) - with m.Case(n): - yield - - with m.Switch(test): - yield case - - -@overload -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[True]) -> Iterable[Optional[int]]: ... - - -@overload -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: Literal[False] = False) -> Iterable[int]: ... - - -def OneHotSwitchDynamic(m: ModuleLike, test: Value, *, default: bool = False) -> Iterable[Optional[int]]: - """Dynamic one-hot switch. - - This function allows simple one-hot matching on signals which can have - variable bit widths. 
- - Example:: - - for i in OneHotSwitchDynamic(m, sig): - # code dependent on the bit index i - ... - - Parameters - ---------- - m : Module - The module for which the matching is defined. - test : Signal - The signal being tested. - default : bool, optional - Whether the matching includes a default case (signified by a None). - """ - count = len(test) - with OneHotSwitch(m, test) as OneHotCase: - for i in range(count): - with OneHotCase(1 << i): - yield i - if default: - with OneHotCase(): - yield None - return - - -class ModuleConnector(Elaboratable): - """ - An Elaboratable to create a new module, which will have all arguments - added as its submodules. - """ - - def __init__(self, *args: HasElaborate, **kwargs: HasElaborate): - """ - Parameters - ---------- - *args - Modules which should be added as anonymous submodules. - **kwargs - Modules which will be added as named submodules. - """ - self.args = args - self.kwargs = kwargs - - def elaborate(self, platform): - m = Module() - - for elem in self.args: - m.submodules += elem - - for name, elem in self.kwargs.items(): - m.submodules[name] = elem - - return m - - -class Scheduler(Elaboratable): - """Scheduler - - An implementation of a round-robin scheduler, which is used in the - transaction subsystem. It is based on Amaranth's round-robin scheduler - but instead of using binary numbers, it uses one-hot encoding for the - `grant` output signal. - - Attributes - ---------- - requests: Signal(count), in - Signals that something (e.g. a transaction) wants to run. When i-th - bit is high, then the i-th agent requests the grant signal. - grant: Signal(count), out - Signals that something (e.g. transaction) is granted to run. It uses - one-hot encoding. - valid : Signal(1), out - Signal that `grant` signals are valid. - """ - - def __init__(self, count: int): - """ - Parameters - ---------- - count : int - Number of agents between which the scheduler should arbitrate. - """ - if not isinstance(count, int) or count < 0: - raise ValueError("Count must be a non-negative integer, not {!r}".format(count)) - self.count = count - - self.requests = Signal(count) - self.grant = Signal(count, init=1) - self.valid = Signal() - - def elaborate(self, platform): - m = Module() - - grant_reg = Signal.like(self.grant) - - for i in OneHotSwitchDynamic(m, grant_reg, default=True): - if i is not None: - m.d.comb += self.grant.eq(grant_reg) - for j in itertools.chain(reversed(range(i)), reversed(range(i + 1, self.count))): - with m.If(self.requests[j]): - m.d.comb += self.grant.eq(1 << j) - else: - m.d.comb += self.grant.eq(0) - - m.d.comb += self.valid.eq(self.requests.any()) - - m.d.sync += grant_reg.eq(self.grant) - - return m - - -class RoundRobin(Elaboratable): - """Round-robin scheduler. - For a given set of requests, the round-robin scheduler will - grant one request. Once it grants a request, if any other - requests are active, it grants the next active request with - a greater number, restarting from zero once it reaches the - highest one. - Use :class:`EnableInserter` to control when the scheduler - is updated. - - Implementation ported from amaranth lib. - - Parameters - ---------- - count : int - Number of requests. - Attributes - ---------- - requests : Signal(count), in - Set of requests. - grant : Signal(range(count)), out - Number of the granted request. Does not change if there are no - active requests. - valid : Signal(), out - Asserted if grant corresponds to an active request. Deasserted - otherwise, i.e. if no requests are active. 
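-
-    A minimal usage sketch (`reqs` and `4` are illustrative)::
-
-        m.submodules.rr = rr = RoundRobin(count=4)
-        m.d.comb += rr.requests.eq(reqs)
-        # when rr.valid is set, rr.grant holds the index of the granted request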
- """ - - def __init__(self, *, count): - if not isinstance(count, int) or count < 0: - raise ValueError("Count must be a non-negative integer, not {!r}".format(count)) - self.count = count - - self.requests = Signal(count) - self.grant = Signal(range(count)) - self.valid = Signal() - - def elaborate(self, platform): - m = Module() - - with m.Switch(self.grant): - for i in range(self.count): - with m.Case(i): - for pred in reversed(range(i)): - with m.If(self.requests[pred]): - m.d.sync += self.grant.eq(pred) - for succ in reversed(range(i + 1, self.count)): - with m.If(self.requests[succ]): - m.d.sync += self.grant.eq(succ) - - m.d.sync += self.valid.eq(self.requests.any()) - - return m - - -class MultiPriorityEncoder(Elaboratable): - """Priority encoder with more outputs - - This is an extension of the `PriorityEncoder` from amaranth that supports - more than one output from an input signal. In other words - it decodes multi-hot encoded signal into lists of signals in binary - format, each with the index of a different high bit in the input. - - Attributes - ---------- - input_width : int - Width of the input signal - outputs_count : int - Number of outputs to generate at once. - input : Signal, in - Signal with 1 on `i`-th bit if `i` can be selected by encoder - outputs : list[Signal], out - Signals with selected indicies, sorted in ascending order, - if the number of ready signals is less than `outputs_count` - then valid signals are at the beginning of the list. - valids : list[Signal], out - One bit for each output signal, indicating whether the output is valid or not. - """ - - def __init__(self, input_width: int, outputs_count: int): - self.input_width = input_width - self.outputs_count = outputs_count - - self.input = Signal(self.input_width) - self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)] - self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)] - - @staticmethod - def create( - m: Module, input_width: int, input: ValueLike, outputs_count: int = 1, name: Optional[str] = None - ) -> list[tuple[Signal, Signal]]: - """Syntax sugar for creating MultiPriorityEncoder - - This static method allows to use MultiPriorityEncoder in a more functional - way. Instead of creating the instance manually, connecting all the signals and - adding a submodule, you can call this function to do it automatically. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - m.submodules += prio_encoder = PriorityEncoder(cnt) - m.d.top_comb += prio_encoder.input.eq(one_hot_singal) - idx = prio_encoder.outputs - valid = prio.encoder.valids - - Parameters - ---------- - m: Module - Module to add the MultiPriorityEncoder to. - input_width : int - Width of the one hot signal. - input : ValueLike - The one hot signal to decode. - outputs_count : int - Number of different decoder outputs to generate at once. Default: 1. - name : Optional[str] - Name to use when adding MultiPriorityEncoder to submodules. - If None, it will be added as an anonymous submodule. The given name - can not be used in a submodule that has already been added. Default: None. - - Returns - ------- - return : list[tuple[Signal, Signal]] - Returns a list with len equal to outputs_count. Each tuple contains - a pair of decoded index on the first position and a valid signal - on the second position. 
- """ - prio_encoder = MultiPriorityEncoder(input_width, outputs_count) - if name is None: - m.submodules += prio_encoder - else: - try: - getattr(m.submodules, name) - raise ValueError(f"Name: {name} is already in use, so MultiPriorityEncoder can not be added with it.") - except AttributeError: - setattr(m.submodules, name, prio_encoder) - m.d.comb += prio_encoder.input.eq(input) - return list(zip(prio_encoder.outputs, prio_encoder.valids)) - - @staticmethod - def create_simple( - m: Module, input_width: int, input: ValueLike, name: Optional[str] = None - ) -> tuple[Signal, Signal]: - """Syntax sugar for creating MultiPriorityEncoder - - This is the same as `create` function, but with `outputs_count` hardcoded to 1. - """ - lst = MultiPriorityEncoder.create(m, input_width, input, outputs_count=1, name=name) - return lst[0] - - def build_tree(self, m: Module, in_sig: Signal, start_idx: int): - assert len(in_sig) > 0 - level_outputs = [ - Signal(range(self.input_width), name=f"_lvl_out_idx{start_idx}_{i}") for i in range(self.outputs_count) - ] - level_valids = [Signal(name=f"_lvl_val_idx{start_idx}_{i}") for i in range(self.outputs_count)] - if len(in_sig) == 1: - with m.If(in_sig): - m.d.comb += level_outputs[0].eq(start_idx) - m.d.comb += level_valids[0].eq(1) - else: - middle = len(in_sig) // 2 - r_in = Signal(middle, name=f"_r_in_idx{start_idx}") - l_in = Signal(len(in_sig) - middle, name=f"_l_in_idx{start_idx}") - m.d.comb += r_in.eq(in_sig[0:middle]) - m.d.comb += l_in.eq(in_sig[middle:]) - r_out, r_val = self.build_tree(m, r_in, start_idx) - l_out, l_val = self.build_tree(m, l_in, start_idx + middle) - - with m.Switch(Cat(r_val)): - for i in range(self.outputs_count + 1): - with m.Case((1 << i) - 1): - for j in range(i): - m.d.comb += level_outputs[j].eq(r_out[j]) - m.d.comb += level_valids[j].eq(r_val[j]) - for j in range(i, self.outputs_count): - m.d.comb += level_outputs[j].eq(l_out[j - i]) - m.d.comb += level_valids[j].eq(l_val[j - i]) - return level_outputs, level_valids - - def elaborate(self, platform): - m = Module() - - level_outputs, level_valids = self.build_tree(m, self.input, 0) - - for k in range(self.outputs_count): - m.d.comb += self.outputs[k].eq(level_outputs[k]) - m.d.comb += self.valids[k].eq(level_valids[k]) - - return m - - -class RingMultiPriorityEncoder(Elaboratable): - """Priority encoder with one or more outputs and flexible start - - This is an extension of the `MultiPriorityEncoder` that supports - flexible start and end indexes. In the standard `MultiPriorityEncoder` - the first bit is always at position 0 and the last is the last bit of - the input signal. In this extended implementation, both can be - selected at runtime. - - This implementation is intended for selection from the circular buffers, - so if `last < first` the encoder will first select bits from - [first, input_width) and then from [0, last). - - Attributes - ---------- - input_width : int - Width of the input signal - outputs_count : int - Number of outputs to generate at once. - input : Signal, in - Signal with 1 on `i`-th bit if `i` can be selected by encoder - first : Signal, in - Index of the first bit in the `input`. Inclusive. - last : Signal, out - Index of the last bit in the `input`. Exclusive. - outputs : list[Signal], out - Signals with selected indicies, sorted in ascending order, - if the number of ready signals is less than `outputs_count` - then valid signals are at the beginning of the list. 
-    valids : list[Signal], out
-        One bit for each output signal, indicating whether the output is valid or not.
-    """
-
-    def __init__(self, input_width: int, outputs_count: int):
-        self.input_width = input_width
-        self.outputs_count = outputs_count
-
-        self.input = Signal(self.input_width)
-        self.first = Signal(range(self.input_width))
-        self.last = Signal(range(self.input_width))
-        self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
-        self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]
-
-    @staticmethod
-    def create(
-        m: Module,
-        input_width: int,
-        input: ValueLike,
-        first: ValueLike,
-        last: ValueLike,
-        outputs_count: int = 1,
-        name: Optional[str] = None,
-    ) -> list[tuple[Signal, Signal]]:
-        """Syntax sugar for creating RingMultiPriorityEncoder
-
-        This static method makes it possible to use RingMultiPriorityEncoder in a more
-        functional way. Instead of creating the instance manually, connecting all the
-        signals and adding a submodule, you can call this function to do it automatically.
-
-        This function is equivalent to:
-
-        .. highlight:: python
-        .. code-block:: python
-
-            prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count)
-            m.submodules += prio_encoder
-            m.d.comb += prio_encoder.input.eq(one_hot_signal)
-            m.d.comb += prio_encoder.first.eq(first)
-            m.d.comb += prio_encoder.last.eq(last)
-            idx = prio_encoder.outputs
-            valid = prio_encoder.valids
-
-        Parameters
-        ----------
-        m: Module
-            Module to add the RingMultiPriorityEncoder to.
-        input_width : int
-            Width of the one-hot signal.
-        input : ValueLike
-            The one-hot signal to decode.
-        first : ValueLike
-            Index of the first bit in the `input`. Inclusive.
-        last : ValueLike
-            Index of the last bit in the `input`. Exclusive.
-        outputs_count : int
-            Number of different decoder outputs to generate at once. Default: 1.
-        name : Optional[str]
-            Name to use when adding RingMultiPriorityEncoder to submodules.
-            If None, it will be added as an anonymous submodule. The given name
-            cannot already be used by a previously added submodule. Default: None.
-
-        Returns
-        -------
-        return : list[tuple[Signal, Signal]]
-            Returns a list of length `outputs_count`. Each tuple contains
-            a decoded index in the first position and a valid signal
-            in the second position.
-        """
-        prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count)
-        if name is None:
-            m.submodules += prio_encoder
-        else:
-            try:
-                getattr(m.submodules, name)
-                raise ValueError(
-                    f"Name: {name} is already in use, so RingMultiPriorityEncoder cannot be added with it."
-                )
-            except AttributeError:
-                setattr(m.submodules, name, prio_encoder)
-        m.d.comb += prio_encoder.input.eq(input)
-        m.d.comb += prio_encoder.first.eq(first)
-        m.d.comb += prio_encoder.last.eq(last)
-        return list(zip(prio_encoder.outputs, prio_encoder.valids))
-
-    @staticmethod
-    def create_simple(
-        m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None
-    ) -> tuple[Signal, Signal]:
-        """Syntax sugar for creating RingMultiPriorityEncoder
-
-        This is the same as the `create` function, but with `outputs_count` hardcoded to 1.
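-
-        For example, picking the oldest ready entry from a circular buffer
-        (signal names are illustrative)::
-
-            idx, valid = RingMultiPriorityEncoder.create_simple(
-                m, buffer_depth, ready_mask, first=head_ptr, last=tail_ptr
-            )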
- """ - lst = RingMultiPriorityEncoder.create(m, input_width, input, first, last, outputs_count=1, name=name) - return lst[0] - - def elaborate(self, platform): - m = Module() - double_input = Signal(2 * self.input_width) - m.d.comb += double_input.eq(Cat(self.input, self.input)) - - last_corrected = Signal(range(self.input_width * 2)) - with m.If(self.first > self.last): - m.d.comb += last_corrected.eq(self.input_width + self.last) - with m.Else(): - m.d.comb += last_corrected.eq(self.last) - - mask = Signal.like(double_input) - m.d.comb += mask.eq((1 << last_corrected) - 1) - - multi_enc_input = (double_input & mask) >> self.first - - m.submodules.multi_enc = multi_enc = MultiPriorityEncoder(self.input_width, self.outputs_count) - m.d.comb += multi_enc.input.eq(multi_enc_input) - for k in range(self.outputs_count): - moved_out = Signal(range(2 * self.input_width)) - m.d.comb += moved_out.eq(multi_enc.outputs[k] + self.first) - corrected_out = Mux(moved_out >= self.input_width, moved_out - self.input_width, moved_out) - - m.d.comb += self.outputs[k].eq(corrected_out) - m.d.comb += self.valids[k].eq(multi_enc.valids[k]) - return m diff --git a/transactron/utils/amaranth_ext/functions.py b/transactron/utils/amaranth_ext/functions.py deleted file mode 100644 index d09c7b53b..000000000 --- a/transactron/utils/amaranth_ext/functions.py +++ /dev/null @@ -1,99 +0,0 @@ -from amaranth import * -from amaranth.utils import bits_for, exact_log2 -from amaranth.lib import data -from collections.abc import Iterable, Mapping -from transactron.utils._typing import SignalBundle - -__all__ = [ - "mod_incr", - "popcount", - "count_leading_zeros", - "count_trailing_zeros", - "flatten_signals", -] - - -def mod_incr(sig: Value, mod: int) -> Value: - """ - Perform `(sig+1) % mod` operation. 
- """ - if mod == 2 ** len(sig): - return sig + 1 - return Mux(sig == mod - 1, 0, sig + 1) - - -def popcount(s: Value): - sum_layers = [s[i] for i in range(len(s))] - - while len(sum_layers) > 1: - if len(sum_layers) % 2: - sum_layers.append(C(0)) - sum_layers = [a + b for a, b in zip(sum_layers[::2], sum_layers[1::2])] - - return sum_layers[0][0 : bits_for(len(s))] - - -def count_leading_zeros(s: Value) -> Value: - def iter(s: Value, step: int) -> Value: - # if no bits left - return empty value - if step == 0: - return C(0) - - # boudaries of upper and lower halfs of the value - partition = 2 ** (step - 1) - current_bit = 1 << (step - 1) - - # recursive call - upper_value = iter(s[partition:], step - 1) - lower_value = iter(s[:partition], step - 1) - - # if there are lit bits in upperhalf - take result directly from recursive value - # otherwise add 1 << (step - 1) to lower value and return - result = Mux(s[partition:].any(), upper_value, lower_value | current_bit) - - return result - - try: - xlen_log = exact_log2(len(s)) - except ValueError: - raise NotImplementedError("CountLeadingZeros - only sizes aligned to power of 2 are supperted") - - value = iter(s, xlen_log) - - # 0 number edge case - # if s == 0 then iter() returns value off by 1 - # this switch negates this effect - high_bit = 1 << xlen_log - - result = Mux(s.any(), value, high_bit) - return result - - -def count_trailing_zeros(s: Value) -> Value: - try: - exact_log2(len(s)) - except ValueError: - raise NotImplementedError("CountTrailingZeros - only sizes aligned to power of 2 are supperted") - - return count_leading_zeros(s[::-1]) - - -def flatten_signals(signals: SignalBundle) -> Iterable[Signal]: - """ - Flattens input data, which can be either a signal, a record, a list (or a dict) of SignalBundle items. 
- - """ - if isinstance(signals, Mapping): - for x in signals.values(): - yield from flatten_signals(x) - elif isinstance(signals, Iterable): - for x in signals: - yield from flatten_signals(x) - elif isinstance(signals, Record): - for x in signals.fields.values(): - yield from flatten_signals(x) - elif isinstance(signals, data.View): - for x, _ in signals.shape(): - yield from flatten_signals(signals[x]) - else: - yield signals diff --git a/transactron/utils/assign.py b/transactron/utils/assign.py deleted file mode 100644 index 4257d2df6..000000000 --- a/transactron/utils/assign.py +++ /dev/null @@ -1,227 +0,0 @@ -from enum import Enum -from typing import Optional, TypeAlias, cast, TYPE_CHECKING -from collections.abc import Sequence, Iterable, Mapping -from amaranth import * -from amaranth.hdl import ShapeLike, ValueCastable -from amaranth.hdl._ast import ArrayProxy, Slice -from amaranth.lib import data -from ._typing import ValueLike - -if TYPE_CHECKING: - from amaranth.hdl._ast import Assign - -__all__ = [ - "AssignType", - "assign", -] - - -class AssignType(Enum): - COMMON = 1 - LHS = 2 - RHS = 3 - ALL = 4 - - -AssignFields: TypeAlias = AssignType | Iterable[str | int] | Mapping[str | int, "AssignFields"] -AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"] | Mapping[int, "AssignArg"] | Sequence["AssignArg"] - - -def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str | int]]: - def flatten_elems(proxy: ArrayProxy): - for elem in proxy.elems: - if isinstance(elem, ArrayProxy): - yield from flatten_elems(elem) - else: - yield elem - - elems = list(flatten_elems(proxy)) - if elems and all(isinstance(el, data.View) for el in elems): - return set.intersection(*[set(cast(data.View, el).shape().members.keys()) for el in elems]) - - -def assign_arg_fields(val: AssignArg) -> Optional[set[str | int]]: - if isinstance(val, ArrayProxy): - return arrayproxy_fields(val) - elif isinstance(val, data.View): - layout = val.shape() - if isinstance(layout, data.StructLayout): - return set(k for k in layout.members) - if isinstance(layout, data.ArrayLayout): - return set(range(layout.length)) - elif isinstance(val, dict): - return set(val.keys()) - elif isinstance(val, list): - return set(range(len(val))) - - -def valuelike_shape(val: ValueLike) -> ShapeLike: - if isinstance(val, Value) or isinstance(val, ValueCastable): - return val.shape() - else: - return Value.cast(val).shape() - - -def is_union(val: AssignArg): - return isinstance(val, data.View) and isinstance(val.shape(), data.UnionLayout) - - -def assign( - lhs: AssignArg, rhs: AssignArg, *, fields: AssignFields = AssignType.RHS, lhs_strict=False, rhs_strict=False -) -> Iterable["Assign"]: - """Safe structured assignment. - - This function recursively generates assignment statements for - field-containing structures. This includes: - Amaranth `View`\\s using `StructLayout`, Python `dict`\\s. In case of - mismatching fields or bit widths, error is raised. - - When both `lhs` and `rhs` are field-containing, `assign` generates - assignment statements according to the value of the `field` parameter. - If either of `lhs` or `rhs` is not field-containing, `assign` checks for - the same bit width and generates a single assignment statement. - - The bit width check is performed if: - - - Any of `lhs` or `rhs` is a `View`. - - Both `lhs` and `rhs` have an explicitly defined shape (e.g. are a - `Signal`, a field of a `View`). - - Parameters - ---------- - lhs : View or Value-castable or dict - View, signal or dict being assigned. 
- rhs : View or Value-castable or dict - View, signal or dict containing assigned values. - fields : AssignType or Iterable or Mapping, optional - Determines which fields will be assigned. Possible values: - - AssignType.COMMON - Only fields common to `lhs` and `rhs` are assigned. - AssignType.LHS - All fields in `lhs` are assigned. If one of them is not present - in `rhs`, an exception is raised. - AssignType.RHS - All fields in `rhs` are assigned. If one of them is not present - in `lhs`, an exception is raised. - AssignType.ALL - Assume that both structures have the same layouts. All fields present - in `lhs` or `rhs` are assigned. - Mapping - Keys are field names, values follow the format for `fields`. - Iterable - Items are field names. For subfields, AssignType.ALL is assumed. - - Returns - ------- - Iterable[Assign] - Generated assignment statements. - - Raises - ------ - ValueError - If the assignment can't be safely performed. - """ - lhs_fields = assign_arg_fields(lhs) - rhs_fields = assign_arg_fields(rhs) - - def rec_call(name: str | int): - subfields = fields - if isinstance(fields, Mapping): - subfields = fields[name] - elif isinstance(fields, Iterable): - subfields = AssignType.ALL - - return assign( - lhs[name], # type: ignore - rhs[name], # type: ignore - fields=subfields, - lhs_strict=isinstance(lhs, ValueLike), - rhs_strict=isinstance(rhs, ValueLike), - ) - - if lhs_fields is not None and rhs_fields is not None: - # asserts for type checking - assert ( - isinstance(lhs, ArrayProxy) - or isinstance(lhs, Mapping) - or isinstance(lhs, Sequence) - or isinstance(lhs, data.View) - ) - assert ( - isinstance(rhs, ArrayProxy) - or isinstance(rhs, Mapping) - or isinstance(rhs, Sequence) - or isinstance(rhs, data.View) - ) - - if fields is AssignType.COMMON: - names = lhs_fields & rhs_fields - elif fields is AssignType.LHS: - names = lhs_fields - elif fields is AssignType.RHS: - names = rhs_fields - elif fields is AssignType.ALL: - names = lhs_fields | rhs_fields - else: - names = set(fields) - - if not names and (lhs_fields or rhs_fields): - raise ValueError("There are no common fields in assigment lhs: {} rhs: {}".format(lhs_fields, rhs_fields)) - - for name in names: - if name not in lhs_fields: - raise KeyError("Field {} not present in lhs".format(name)) - if name not in rhs_fields: - raise KeyError("Field {} not present in rhs".format(name)) - - yield from rec_call(name) - elif is_union(lhs) and isinstance(rhs, Mapping) or isinstance(lhs, Mapping) and is_union(rhs): - mapping, union = (lhs, rhs) if isinstance(lhs, Mapping) else (rhs, lhs) - - # asserts for type checking - assert isinstance(mapping, Mapping) - assert isinstance(union, data.View) - - if len(mapping) != 1: - raise ValueError(f"Non-singleton mapping on union assignment lhs: {lhs} rhs: {rhs}") - name = next(iter(mapping)) - - if name not in union.shape().members: - raise ValueError(f"Field {name} not present in union {union}") - - yield from rec_call(name) - else: - if not isinstance(fields, AssignType): - raise ValueError("Fields on assigning non-structures lhs: {} rhs: {}".format(lhs, rhs)) - if not isinstance(lhs, ValueLike) or not isinstance(rhs, ValueLike): - raise TypeError("Unsupported assignment lhs: {} rhs: {}".format(lhs, rhs)) - - # If a single-value structure, assign its only field - while lhs_fields is not None and len(lhs_fields) == 1: - lhs = lhs[next(iter(lhs_fields))] # type: ignore - lhs_fields = assign_arg_fields(lhs) - while rhs_fields is not None and len(rhs_fields) == 1: - rhs = 
rhs[next(iter(rhs_fields))] # type: ignore - rhs_fields = assign_arg_fields(rhs) - - def has_explicit_shape(val: ValueLike): - return isinstance(val, (Signal, ArrayProxy, Slice, ValueCastable)) - - if ( - isinstance(lhs, ValueCastable) - or isinstance(rhs, ValueCastable) - or (lhs_strict or has_explicit_shape(lhs)) - and (rhs_strict or has_explicit_shape(rhs)) - ): - if valuelike_shape(lhs) != valuelike_shape(rhs): - raise ValueError( - "Shapes not matching: lhs: {} {} rhs: {} {}".format( - valuelike_shape(lhs), repr(lhs), valuelike_shape(rhs), repr(rhs) - ) - ) - - lhs_val = Value.cast(lhs) - rhs_val = Value.cast(rhs) - - yield lhs_val.eq(rhs_val) diff --git a/transactron/utils/data_repr.py b/transactron/utils/data_repr.py deleted file mode 100644 index acd7c7505..000000000 --- a/transactron/utils/data_repr.py +++ /dev/null @@ -1,143 +0,0 @@ -from collections.abc import Iterable, Mapping -from ._typing import ShapeLike, MethodLayout -from typing import Any, Sized -from statistics import fmean -from amaranth.lib.data import StructLayout - - -__all__ = [ - "make_hashable", - "align_to_power_of_two", - "align_down_to_power_of_two", - "bits_from_int", - "layout_subset", - "data_layout", - "signed_to_int", - "int_to_signed", - "neg", - "average_dict_of_lists", -] - - -def layout_subset(layout: StructLayout, *, fields: set[str]) -> StructLayout: - return StructLayout({item: value for item, value in layout.members.items() if item in fields}) - - -def make_hashable(val): - if isinstance(val, Mapping): - return frozenset(((k, make_hashable(v)) for k, v in val.items())) - elif isinstance(val, Iterable): - return (make_hashable(v) for v in val) - else: - return val - - -def align_to_power_of_two(num: int, power: int) -> int: - """Rounds up a number to the given power of two. - - Parameters - ---------- - num : int - The number to align. - power : int - The power of two to align to. - - Returns - ------- - int - The aligned number. - """ - mask = 2**power - 1 - if num & mask == 0: - return num - return (num & ~mask) + 2**power - - -def align_down_to_power_of_two(num: int, power: int) -> int: - """Rounds down a number to the given power of two. - - Parameters - ---------- - num : int - The number to align. - power : int - The power of two to align to. - - Returns - ------- - int - The aligned number. - """ - mask = 2**power - 1 - - return num & ~mask - - -def bits_from_int(num: int, lower: int, length: int): - """Returns [`lower`:`lower`+`length`) bits from integer `num`.""" - return (num >> lower) & ((1 << (length)) - 1) - - -def data_layout(val: ShapeLike) -> MethodLayout: - return [("data", val)] - - -def neg(x: int, xlen: int) -> int: - """ - Computes the negation of a number in the U2 system. - - Parameters - ---------- - x: int - Number in U2 system. - xlen : int - Bit width of x. - - Returns - ------- - return : int - Negation of x in the U2 system. - """ - return (-x) & (2**xlen - 1) - - -def int_to_signed(x: int, xlen: int) -> int: - """ - Converts a Python integer into its U2 representation. - - Parameters - ---------- - x: int - Signed Python integer. - xlen : int - Bit width of x. - - Returns - ------- - return : int - Representation of x in the U2 system. - """ - return x & (2**xlen - 1) - - -def signed_to_int(x: int, xlen: int) -> int: - """ - Changes U2 representation into Python integer - - Parameters - ---------- - x: int - Number in U2 system. - xlen : int - Bit width of x. - - Returns - ------- - return : int - Representation of x as signed Python integer. 
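-
-    For example::
-
-        signed_to_int(0xFF, 8)  # == -1
-        signed_to_int(0x7F, 8)  # == 127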
- """ - return x | -(x & (2 ** (xlen - 1))) - - -def average_dict_of_lists(d: Mapping[Any, Sized]) -> float: - return fmean(map(lambda xs: len(xs), d.values())) diff --git a/transactron/utils/debug_signals.py b/transactron/utils/debug_signals.py deleted file mode 100644 index 4442e4dd4..000000000 --- a/transactron/utils/debug_signals.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional -from amaranth import * -from ._typing import SignalBundle, HasDebugSignals -from collections.abc import Collection, Mapping - - -def auto_debug_signals(thing) -> SignalBundle: - """Automatic debug signal generation. - - Exposes class attributes with debug signals (Amaranth `Signal`\\s, - `Record`\\s, `Array`\\s and `Elaboratable`\\s, `Method`\\s, classes - which define `debug_signals`). Used for generating ``gtkw`` files in - tests, for use in ``gtkwave``. - """ - - def auto_debug_signals_internal(thing, *, _visited: set) -> Optional[SignalBundle]: - # Please note, that the set `_visited` is used to memorise visited elements - # to break reference cycles. There is only one instance of this set, for whole - # `auto_debug_signals` recursion stack. It is being mutated by adding to it more - # elements id, so that caller know what was visited by callee. - smap: dict[str, SignalBundle] = {} - - # Check for reference cycles e.g. Amaranth's MustUse - if id(thing) in _visited: - return None - _visited.add(id(thing)) - - match thing: - case HasDebugSignals(): - return thing.debug_signals() - # avoid infinite recursion (strings are `Collection`s of strings) - case str(): - return None - case Collection() | Mapping(): - match thing: - case Collection(): - f_iter = enumerate(thing) - case Mapping(): - f_iter = thing.items() - for i, e in f_iter: - sublist = auto_debug_signals_internal(e, _visited=_visited) - if sublist is not None: - smap[f"[{i}]"] = sublist - if smap: - return smap - return None - case Array(): - for i, e in enumerate(thing): - if isinstance(e, Record): - e.name = f"[{i}]" - return thing - case Signal() | Record(): - return thing - case _: - try: - vs = vars(thing) - except (KeyError, AttributeError, TypeError): - return None - - for v in vs: - a = getattr(thing, v) - - # ignore private fields (mostly to ignore _MustUse_context to get pretty print) - if v[0] == "_": - continue - - dsignals = auto_debug_signals_internal(a, _visited=_visited) - if dsignals is not None: - smap[v] = dsignals - if smap: - return smap - return None - - ret = auto_debug_signals_internal(thing, _visited=set()) - if ret is None: - return [] - return ret diff --git a/transactron/utils/depcache.py b/transactron/utils/depcache.py deleted file mode 100644 index 0fbe356c3..000000000 --- a/transactron/utils/depcache.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import TypeVar, Type, Any - -from transactron.utils import make_hashable - -__all__ = ["DependentCache"] - -T = TypeVar("T") - - -class DependentCache: - """ - Cache for classes, that depend on the `DependentCache` class itself. - - Cached classes may accept one positional argument in the constructor, where this `DependentCache` class will - be passed. Classes may define any number keyword arguments in the constructor and separate cache entry will - be created for each set of the arguments. - - Methods - ------- - get: T, **kwargs -> T - Gets class `cls` from cache. Caches `cls` reference if this is the first call for it. - Optionally accepts `kwargs` for additional arguments in `cls` constructor. 
- - """ - - def __init__(self): - self._depcache: dict[tuple[Type, Any], Type] = {} - - def get(self, cls: Type[T], **kwargs) -> T: - cache_key = make_hashable(kwargs) - v = self._depcache.get((cls, cache_key), None) - if v is None: - positional_count = cls.__init__.__code__.co_argcount - - # first positional arg is `self` field, second may be `DependentCache` - if positional_count > 2: - raise KeyError(f"Too many positional arguments in {cls!r} constructor") - - if positional_count > 1: - v = cls(self, **kwargs) - else: - v = cls(**kwargs) - self._depcache[(cls, cache_key)] = v - return v diff --git a/transactron/utils/dependencies.py b/transactron/utils/dependencies.py deleted file mode 100644 index 010b03f22..000000000 --- a/transactron/utils/dependencies.py +++ /dev/null @@ -1,160 +0,0 @@ -from collections import defaultdict - -from abc import abstractmethod, ABC -from typing import Any, Generic, TypeVar - - -__all__ = ["DependencyManager", "DependencyKey", "DependencyContext", "SimpleKey", "ListKey"] - -T = TypeVar("T") -U = TypeVar("U") - - -class DependencyKey(Generic[T, U], ABC): - """Base class for dependency keys. - - Dependency keys are used to access dependencies in the `DependencyManager`. - Concrete instances of dependency keys should be frozen data classes. - - Parameters - ---------- - lock_on_get: bool, default: True - Specifies if no new dependencies should be added to key if it was already read by `get_dependency`. - cache: bool, default: True - If true, result of the `combine` method is cached and subsequent calls to `get_dependency` - will return the value in the cache. Adding a new dependency clears the cache. - empty_valid: bool, default : False - Specifies if getting key dependency without any added dependencies is valid. If set to `False`, that - action would cause raising `KeyError`. - """ - - @abstractmethod - def combine(self, data: list[T]) -> U: - """Combine multiple dependencies with the same key. - - This method is used to generate the value returned from `get_dependency` - in the `DependencyManager`. It takes dependencies added to the key - using `add_dependency` and combines them to a single result. - - Different implementations of `combine` give different combining behavior - for different kinds of keys. - """ - raise NotImplementedError() - - @abstractmethod - def __hash__(self) -> int: - """The `__hash__` method is made abstract so that only concrete keys - can be instanced. It is automatically overridden in frozen data - classes. - """ - raise NotImplementedError() - - lock_on_get: bool = True - cache: bool = True - empty_valid: bool = False - - -class SimpleKey(Generic[T], DependencyKey[T, T]): - """Base class for simple dependency keys. - - Simple dependency keys are used when there is an one-to-one relation between - keys and dependencies. If more than one dependency is added to a simple key, - an error is raised. - - Parameters - ---------- - default_value: T - Specifies the default value returned when no dependencies are added. To - enable it `empty_valid` must be True. - """ - - default_value: T - - def combine(self, data: list[T]) -> T: - if len(data) == 0: - return self.default_value - if len(data) != 1: - raise RuntimeError(f"Key {self} assigned {len(data)} values, expected 1") - return data[0] - - -class ListKey(Generic[T], DependencyKey[T, list[T]]): - """Base class for list key. - - List keys are used when there is an one-to-many relation between keys - and dependecies. Provides list of dependencies. 
- """ - - empty_valid = True - - def combine(self, data: list[T]) -> list[T]: - return data - - -class DependencyManager: - """Dependency manager. - - Tracks dependencies across the core. - """ - - def __init__(self): - self.dependencies: defaultdict[DependencyKey, list] = defaultdict(list) - self.cache: dict[DependencyKey, Any] = {} - self.locked_dependencies: set[DependencyKey] = set() - - def add_dependency(self, key: DependencyKey[T, Any], dependency: T) -> None: - """Adds a new dependency to a key. - - Depending on the key type, a key can have a single dependency or - multple dependencies added to it. - """ - - if key in self.locked_dependencies: - raise KeyError(f"Trying to add dependency to {key} that was already read and is locked") - - self.dependencies[key].append(dependency) - - if key in self.cache: - del self.cache[key] - - def get_dependency(self, key: DependencyKey[Any, U]) -> U: - """Gets the dependency for a key. - - The way dependencies are interpreted is dependent on the key type. - """ - if not key.empty_valid and key not in self.dependencies: - raise KeyError(f"Dependency {key} not provided") - - if key in self.cache: - return self.cache[key] - - if key.lock_on_get: - self.locked_dependencies.add(key) - - val = key.combine(self.dependencies[key]) - - if key.cache: - self.cache[key] = val - - return val - - -class DependencyContext: - stack: list[DependencyManager] = [] - - def __init__(self, manager: DependencyManager): - self.manager = manager - - def __enter__(self): - self.stack.append(self.manager) - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - top = self.stack.pop() - assert self.manager is top - - @classmethod - def get(cls) -> DependencyManager: - if not cls.stack: - raise RuntimeError("DependencyContext stack is empty") - return cls.stack[-1] diff --git a/transactron/utils/gen.py b/transactron/utils/gen.py deleted file mode 100644 index 780e151cd..000000000 --- a/transactron/utils/gen.py +++ /dev/null @@ -1,258 +0,0 @@ -from dataclasses import dataclass, field -from dataclasses_json import dataclass_json -from typing import Optional, TypeAlias - -from amaranth import * -from amaranth.back import verilog -from amaranth.hdl import Fragment - -from transactron.core import TransactionManager -from transactron.core.keys import TransactionManagerKey -from transactron.core.manager import MethodMap -from transactron.lib.metrics import HardwareMetricsManager -from transactron.lib import logging -from transactron.utils.dependencies import DependencyContext -from transactron.utils.idgen import IdGenerator -from transactron.utils._typing import AbstractInterface -from transactron.profiler import ProfileData - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from amaranth.hdl._ast import SignalDict - - -__all__ = [ - "MetricLocation", - "GeneratedLog", - "GenerationInfo", - "generate_verilog", -] - -SignalHandle: TypeAlias = list[str] -"""The location of a signal is a list of Verilog identifiers that denote a path -consisting of module names (and the signal name at the end) leading -to the signal wire.""" - - -@dataclass_json -@dataclass -class MetricLocation: - """Information about the location of a metric in the generated Verilog code. - - Attributes - ---------- - regs : dict[str, SignalHandle] - The location of each register of that metric. 
- """ - - regs: dict[str, SignalHandle] = field(default_factory=dict) - - -@dataclass_json -@dataclass -class TransactionSignalsLocation: - """Information about transaction control signals in the generated Verilog code. - - Attributes - ---------- - request: list[str] - The location of the ``request`` signal. - runnable: list[str] - The location of the ``runnable`` signal. - grant: list[str] - The location of the ``grant`` signal. - """ - - request: list[str] - runnable: list[str] - grant: list[str] - - -@dataclass_json -@dataclass -class MethodSignalsLocation: - """Information about method control signals in the generated Verilog code. - - Attributes - ---------- - run: list[str] - The location of the ``run`` signal. - """ - - run: list[str] - - -@dataclass_json -@dataclass -class GeneratedLog(logging.LogRecordInfo): - """Information about a log record in the generated Verilog code. - - Attributes - ---------- - trigger_location : SignalHandle - The location of the trigger signal. - fields_location : list[SignalHandle] - Locations of the log fields. - """ - - trigger_location: SignalHandle - fields_location: list[SignalHandle] - - -@dataclass_json -@dataclass -class GenerationInfo: - """Various information about the generated circuit. - - Attributes - ---------- - metrics_location : dict[str, MetricInfo] - Mapping from a metric name to an object storing Verilog locations - of its registers. - logs : list[GeneratedLog] - Locations and metadata for all log records. - """ - - metrics_location: dict[str, MetricLocation] - transaction_signals_location: dict[int, TransactionSignalsLocation] - method_signals_location: dict[int, MethodSignalsLocation] - profile_data: ProfileData - logs: list[GeneratedLog] - - def encode(self, file_name: str): - """ - Encodes the generation information as JSON and saves it to a file. - """ - with open(file_name, "w") as fp: - fp.write(self.to_json()) # type: ignore - - @staticmethod - def decode(file_name: str) -> "GenerationInfo": - """ - Loads the generation information from a JSON file. - """ - with open(file_name, "r") as fp: - return GenerationInfo.from_json(fp.read()) # type: ignore - - -def escape_verilog_identifier(identifier: str) -> str: - """ - Escapes a Verilog identifier according to the language standard. - - From IEEE Std 1364-2001 (IEEE Standard VerilogĀ® Hardware Description Language) - - "2.7.1 Escaped identifiers - - Escaped identifiers shall start with the backslash character and end with white - space (space, tab, newline). They provide a means of including any of the printable ASCII - characters in an identifier (the decimal values 33 through 126, or 21 through 7E in hexadecimal)." - """ - - # The standard says how to escape a identifier, but not when. So this is - # a non-exhaustive list of characters that Yosys escapes (it is used - # by Amaranth when generating Verilog code). - characters_to_escape = [".", "$", "-"] - - for char in characters_to_escape: - if char in identifier: - return f"\\{identifier} " - - return identifier - - -def get_signal_location(signal: Signal, name_map: "SignalDict") -> SignalHandle: - raw_location = name_map[signal] - return raw_location - - -def collect_metric_locations(name_map: "SignalDict") -> dict[str, MetricLocation]: - metrics_location: dict[str, MetricLocation] = {} - - # Collect information about the location of metric registers in the generated code. 
- metrics_manager = HardwareMetricsManager() - for metric_name, metric in metrics_manager.get_metrics().items(): - metric_loc = MetricLocation() - for reg_name in metric.regs: - metric_loc.regs[reg_name] = get_signal_location( - metrics_manager.get_register_value(metric_name, reg_name), name_map - ) - - metrics_location[metric_name] = metric_loc - - return metrics_location - - -def collect_transaction_method_signals( - transaction_manager: TransactionManager, name_map: "SignalDict" -) -> tuple[dict[int, TransactionSignalsLocation], dict[int, MethodSignalsLocation]]: - transaction_signals_location: dict[int, TransactionSignalsLocation] = {} - method_signals_location: dict[int, MethodSignalsLocation] = {} - - method_map = MethodMap(transaction_manager.transactions) - get_id = IdGenerator() - - for transaction in method_map.transactions: - request_loc = get_signal_location(transaction.request, name_map) - runnable_loc = get_signal_location(transaction.runnable, name_map) - grant_loc = get_signal_location(transaction.grant, name_map) - transaction_signals_location[get_id(transaction)] = TransactionSignalsLocation( - request_loc, runnable_loc, grant_loc - ) - - for method in method_map.methods: - run_loc = get_signal_location(method.run, name_map) - method_signals_location[get_id(method)] = MethodSignalsLocation(run_loc) - - return (transaction_signals_location, method_signals_location) - - -def collect_logs(name_map: "SignalDict") -> list[GeneratedLog]: - logs: list[GeneratedLog] = [] - - # Get all records. - for record in logging.get_log_records(0): - trigger_loc = get_signal_location(record.trigger, name_map) - fields_loc = [get_signal_location(field, name_map) for field in record.fields] - log = GeneratedLog( - logger_name=record.logger_name, - level=record.level, - format_str=record.format_str, - location=record.location, - trigger_location=trigger_loc, - fields_location=fields_loc, - ) - logs.append(log) - - return logs - - -def generate_verilog( - elaboratable: Elaboratable, ports: Optional[list[Value]] = None, top_name: str = "top" -) -> tuple[str, GenerationInfo]: - # The ports logic is copied (and simplified) from amaranth.back.verilog.convert. - # Unfortunately, the convert function doesn't return the name map. 
- if ports is None and isinstance(elaboratable, AbstractInterface): - ports = [] - for _, _, value in elaboratable.signature.flatten(elaboratable): - ports.append(Value.cast(value)) - elif ports is None: - raise TypeError("The `generate_verilog()` function requires a `ports=` argument") - - fragment = Fragment.get(elaboratable, platform=None).prepare(ports=ports) - verilog_text, name_map = verilog.convert_fragment(fragment, name=top_name, emit_src=True, strip_internal_attrs=True) - - transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey()) - transaction_signals, method_signals = collect_transaction_method_signals( - transaction_manager, name_map # type: ignore - ) - profile_data, _ = ProfileData.make(transaction_manager) - gen_info = GenerationInfo( - metrics_location=collect_metric_locations(name_map), # type: ignore - transaction_signals_location=transaction_signals, - method_signals_location=method_signals, - profile_data=profile_data, - logs=collect_logs(name_map), - ) - - return verilog_text, gen_info diff --git a/transactron/utils/idgen.py b/transactron/utils/idgen.py deleted file mode 100644 index 459f3160e..000000000 --- a/transactron/utils/idgen.py +++ /dev/null @@ -1,15 +0,0 @@ -__all__ = ["IdGenerator"] - - -class IdGenerator: - def __init__(self): - self.id_map = dict[int, int]() - self.id_seq = 0 - - def __call__(self, obj): - try: - return self.id_map[id(obj)] - except KeyError: - self.id_seq += 1 - self.id_map[id(obj)] = self.id_seq - return self.id_seq diff --git a/transactron/utils/transactron_helpers.py b/transactron/utils/transactron_helpers.py deleted file mode 100644 index 1fae8827a..000000000 --- a/transactron/utils/transactron_helpers.py +++ /dev/null @@ -1,160 +0,0 @@ -import sys -from contextlib import contextmanager -from typing import Optional, Any, Concatenate, TypeGuard, TypeVar -from collections.abc import Callable, Mapping, Sequence -from ._typing import ROGraph, GraphCC, SrcLoc, MethodLayout, MethodStruct, ShapeLike, LayoutList, LayoutListField -from inspect import Parameter, signature -from itertools import count -from amaranth import * -from amaranth import tracer -from amaranth.lib.data import StructLayout - - -__all__ = [ - "longest_common_prefix", - "silence_mustuse", - "get_caller_class_name", - "def_helper", - "method_def_helper", - "mock_def_helper", - "get_src_loc", - "from_method_layout", - "make_layout", - "extend_layout", -] - -T = TypeVar("T") -U = TypeVar("U") - - -def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]: - """_graph_ccs - - Find connected components in a graph. - - Parameters - ---------- - gr : Mapping[T, Iterable[T]] - Graph in which we should find connected components. Encoded using - adjacency lists. - - Returns - ------- - ccs : List[Set[T]] - Connected components of the graph `gr`. 
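-
-    For example::
-
-        _graph_ccs({1: [2], 2: [1], 3: []})  # -> [{1, 2}, {3}]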
- """ - ccs = [] - cc = set() - visited = set() - - for v in gr.keys(): - q = [v] - while q: - w = q.pop() - if w in visited: - continue - visited.add(w) - cc.add(w) - q.extend(gr[w]) - if cc: - ccs.append(cc) - cc = set() - - return ccs - - -def longest_common_prefix(*seqs: Sequence[T]) -> Sequence[T]: - if not seqs: - raise ValueError("no arguments") - for i, letter_group in enumerate(zip(*seqs)): - if len(set(letter_group)) > 1: - return seqs[0][:i] - return min(seqs, key=lambda s: len(s)) - - -def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]: - parameters = signature(func).parameters - return ( - len(parameters) >= 1 - and next(iter(parameters)) == name - and parameters[name].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY} - and parameters[name].annotation in {Parameter.empty, tp} - ) - - -def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kwargs) -> T: - try: - parameters = signature(func).parameters - except ValueError: - raise TypeError(f"Invalid python method signature for {func} (missing `self` for class-level mock?)") - - kw_parameters = set( - n for n, p in parameters.items() if p.kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY} - ) - if len(parameters) == 1 and has_first_param(func, "arg", tp): - return func(arg) - elif kw_parameters <= kwargs.keys(): - return func(**kwargs) - else: - raise TypeError(f"Invalid {description}: {func}") - - -def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: - return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) - - -def method_def_helper(method, func: Callable[..., T], arg: MethodStruct) -> T: - kwargs = {k: arg[k] for k in arg.shape().members} - return def_helper(f"method definition for {method}", func, MethodStruct, arg, **kwargs) - - -def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]: - try: - for d in count(2): - caller_frame = sys._getframe(d) - if "self" in caller_frame.f_locals: - owner = caller_frame.f_locals["self"] - if isinstance(owner, Elaboratable): - return owner, owner.__class__.__name__ - except ValueError: - pass - - if default is not None: - return None, default - else: - raise RuntimeError("Not called from a method") - - -@contextmanager -def silence_mustuse(elaboratable: Elaboratable): - try: - yield - except Exception: - elaboratable._MustUse__silence = True # type: ignore - raise - - -def get_src_loc(src_loc: int | SrcLoc) -> SrcLoc: - return tracer.get_src_loc(1 + src_loc) if isinstance(src_loc, int) else src_loc - - -def from_layout_field(shape: ShapeLike | LayoutList) -> ShapeLike: - if isinstance(shape, list): - return from_method_layout(shape) - else: - return shape - - -def make_layout(*fields: LayoutListField) -> StructLayout: - return from_method_layout(fields) - - -def extend_layout(layout: StructLayout, *fields: LayoutListField) -> StructLayout: - return StructLayout(layout.members | from_method_layout(fields).members) - - -def from_method_layout(layout: MethodLayout) -> StructLayout: - if isinstance(layout, StructLayout): - return layout - else: - return StructLayout({k: from_layout_field(v) for k, v in layout})