diff --git a/coreblocks/utils/_typing.py b/coreblocks/utils/_typing.py index 5e27a229d..d12c2611b 100644 --- a/coreblocks/utils/_typing.py +++ b/coreblocks/utils/_typing.py @@ -6,6 +6,8 @@ TypeAlias, TypeVar, runtime_checkable, + Union, + Any, ) from collections.abc import Iterable, Mapping, Sequence from contextlib import AbstractContextManager @@ -30,6 +32,10 @@ SignalBundle: TypeAlias = Signal | Record | View | Iterable["SignalBundle"] | Mapping[str, "SignalBundle"] LayoutList: TypeAlias = list[tuple[str, "ShapeLike | LayoutList"]] +RecordIntDict: TypeAlias = Mapping[str, Union[int, "RecordIntDict"]] +RecordIntDictRet: TypeAlias = Mapping[str, Any] # full typing hard to work with +RecordValueDict: TypeAlias = Mapping[str, Union[ValueLike, "RecordValueDict"]] + class _ModuleBuilderDomainsLike(Protocol): def __getattr__(self, name: str) -> _ModuleBuilderDomain: diff --git a/test/common.py b/test/common.py deleted file mode 100644 index 76ede88f4..000000000 --- a/test/common.py +++ /dev/null @@ -1,461 +0,0 @@ -import unittest -import os -import functools -import random -from contextlib import contextmanager, nullcontext -from typing import Callable, Generic, Mapping, Union, Generator, TypeVar, Optional, Any, cast, Type, TypeGuard - -from amaranth import * -from amaranth.hdl.ast import Statement -from amaranth.sim import * -from amaranth.sim.core import Command - -from transactron.core import SignalBundle, Method, TransactionModule -from transactron.lib import AdapterBase, AdapterTrans -from transactron._utils import mock_def_helper -from coreblocks.utils import ValueLike, HasElaborate, HasDebugSignals, auto_debug_signals, LayoutLike, ModuleConnector -from .gtkw_extension import write_vcd_ext - - -T = TypeVar("T") -RecordValueDict = Mapping[str, Union[ValueLike, "RecordValueDict"]] -RecordIntDict = Mapping[str, Union[int, "RecordIntDict"]] -RecordIntDictRet = Mapping[str, Any] # full typing hard to work with -TestGen = Generator[Command | Value | Statement | None, Any, T] -_T_nested_collection = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] - - -def data_layout(val: int) -> LayoutLike: - return [("data", val)] - - -def set_inputs(values: RecordValueDict, field: Record) -> TestGen[None]: - for name, value in values.items(): - if isinstance(value, dict): - yield from set_inputs(value, getattr(field, name)) - else: - yield getattr(field, name).eq(value) - - -def get_outputs(field: Record) -> TestGen[RecordIntDict]: - # return dict of all signal values in a record because amaranth's simulator can't read all - # values of a Record in a single yield - it can only read Values (Signals) - result = {} - for name, _, _ in field.layout: - val = getattr(field, name) - if isinstance(val, Signal): - result[name] = yield val - else: # field is a Record - result[name] = yield from get_outputs(val) - return result - - -def neg(x: int, xlen: int) -> int: - """ - Computes the negation of a number in the U2 system. - - Parameters - ---------- - x: int - Number in U2 system. - xlen : int - Bit width of x. - - Returns - ------- - return : int - Negation of x in the U2 system. - """ - return (-x) & (2**xlen - 1) - - -def int_to_signed(x: int, xlen: int) -> int: - """ - Converts a Python integer into its U2 representation. - - Parameters - ---------- - x: int - Signed Python integer. - xlen : int - Bit width of x. - - Returns - ------- - return : int - Representation of x in the U2 system. - """ - return x & (2**xlen - 1) - - -def signed_to_int(x: int, xlen: int) -> int: - """ - Changes U2 representation into Python integer - - Parameters - ---------- - x: int - Number in U2 system. - xlen : int - Bit width of x. - - Returns - ------- - return : int - Representation of x as signed Python integer. - """ - return x | -(x & (2 ** (xlen - 1))) - - -def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]: - if isinstance(cont, (list, dict)): - if isinstance(cont, dict): - cont = cont.values() - return all([guard_nested_collection(elem, t) for elem in cont]) - elif isinstance(cont, t): - return True - else: - return False - - -_T_HasElaborate = TypeVar("_T_HasElaborate", bound=HasElaborate) - - -class SimpleTestCircuit(Elaboratable, Generic[_T_HasElaborate]): - def __init__(self, dut: _T_HasElaborate): - self._dut = dut - self._io: dict[str, _T_nested_collection[TestbenchIO]] = {} - - def __getattr__(self, name: str) -> Any: - return self._io[name] - - def elaborate(self, platform): - def transform_methods_to_testbenchios( - container: _T_nested_collection[Method], - ) -> tuple[_T_nested_collection["TestbenchIO"], Union[ModuleConnector, "TestbenchIO"]]: - if isinstance(container, list): - tb_list = [] - mc_list = [] - for elem in container: - tb, mc = transform_methods_to_testbenchios(elem) - tb_list.append(tb) - mc_list.append(mc) - return tb_list, ModuleConnector(*mc_list) - elif isinstance(container, dict): - tb_dict = {} - mc_dict = {} - for name, elem in container.items(): - tb, mc = transform_methods_to_testbenchios(elem) - tb_dict[name] = tb - mc_dict[name] = mc - return tb_dict, ModuleConnector(*mc_dict) - else: - tb = TestbenchIO(AdapterTrans(container)) - return tb, tb - - m = Module() - - m.submodules.dut = self._dut - - for name, attr in vars(self._dut).items(): - if guard_nested_collection(attr, Method) and attr: - tb_cont, mc = transform_methods_to_testbenchios(attr) - self._io[name] = tb_cont - m.submodules[name] = mc - - return m - - def debug_signals(self): - sigs = {"_dut": auto_debug_signals(self._dut)} - for name, io in self._io.items(): - sigs[name] = auto_debug_signals(io) - return sigs - - -class TestModule(Elaboratable): - def __init__(self, tested_module: HasElaborate, add_transaction_module): - self.tested_module = TransactionModule(tested_module) if add_transaction_module else tested_module - self.add_transaction_module = add_transaction_module - - def elaborate(self, platform) -> HasElaborate: - m = Module() - - # so that Amaranth allows us to use add_clock - _dummy = Signal() - m.d.sync += _dummy.eq(1) - - m.submodules.tested_module = self.tested_module - - return m - - -class PysimSimulator(Simulator): - def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True, traces_file=None): - test_module = TestModule(module, add_transaction_module) - tested_module = test_module.tested_module - super().__init__(test_module) - - clk_period = 1e-6 - self.add_clock(clk_period) - - if isinstance(tested_module, HasDebugSignals): - extra_signals = tested_module.debug_signals - else: - extra_signals = functools.partial(auto_debug_signals, tested_module) - - if traces_file: - traces_dir = "test/__traces__" - os.makedirs(traces_dir, exist_ok=True) - # Signal handling is hacky and accesses Simulator internals. - # TODO: try to merge with Amaranth. - if isinstance(extra_signals, Callable): - extra_signals = extra_signals() - clocks = [d.clk for d in cast(Any, self)._fragment.domains.values()] - - self.ctx = write_vcd_ext( - cast(Any, self)._engine, - f"{traces_dir}/{traces_file}.vcd", - f"{traces_dir}/{traces_file}.gtkw", - traces=[clocks, extra_signals], - ) - else: - self.ctx = nullcontext() - - self.deadline = clk_period * max_cycles - - def run(self) -> bool: - with self.ctx: - self.run_until(self.deadline) - - return not self.advance() - - -class TestCaseWithSimulator(unittest.TestCase): - @contextmanager - def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True): - traces_file = None - if "__COREBLOCKS_DUMP_TRACES" in os.environ: - traces_file = unittest.TestCase.id(self) - - sim = PysimSimulator( - module, max_cycles=max_cycles, add_transaction_module=add_transaction_module, traces_file=traces_file - ) - yield sim - res = sim.run() - - self.assertTrue(res, "Simulation time limit exceeded") - - def tick(self, cycle_cnt=1): - """ - Yields for the given number of cycles. - """ - - for _ in range(cycle_cnt): - yield - - def random_wait(self, max_cycle_cnt): - """ - Wait for a random amount of cycles in range [1, max_cycle_cnt) - """ - yield from self.tick(random.randrange(max_cycle_cnt)) - - -class TestbenchIO(Elaboratable): - def __init__(self, adapter: AdapterBase): - self.adapter = adapter - - def elaborate(self, platform): - m = Module() - m.submodules += self.adapter - return m - - # Low-level operations - - def set_enable(self, en) -> TestGen[None]: - yield self.adapter.en.eq(1 if en else 0) - - def enable(self) -> TestGen[None]: - yield from self.set_enable(True) - - def disable(self) -> TestGen[None]: - yield from self.set_enable(False) - - def done(self) -> TestGen[int]: - return (yield self.adapter.done) - - def wait_until_done(self) -> TestGen[None]: - while (yield self.adapter.done) != 1: - yield - - def set_inputs(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from set_inputs(data, self.adapter.data_in) - - def get_outputs(self) -> TestGen[RecordIntDictRet]: - return (yield from get_outputs(self.adapter.data_out)) - - # Operations for AdapterTrans - - def call_init(self, data: RecordValueDict = {}, /, **kwdata: ValueLike | RecordValueDict) -> TestGen[None]: - if data and kwdata: - raise TypeError("call_init() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.enable() - yield from self.set_inputs(data) - - def call_result(self) -> TestGen[Optional[RecordIntDictRet]]: - if (yield from self.done()): - return (yield from self.get_outputs()) - return None - - def call_do(self) -> TestGen[RecordIntDict]: - while (outputs := (yield from self.call_result())) is None: - yield - yield from self.disable() - return outputs - - def call_try( - self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict - ) -> TestGen[Optional[RecordIntDictRet]]: - if data and kwdata: - raise TypeError("call_try() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.call_init(data) - yield - outputs = yield from self.call_result() - yield from self.disable() - return outputs - - def call(self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict) -> TestGen[RecordIntDictRet]: - if data and kwdata: - raise TypeError("call() takes either a single dict or keyword arguments") - if not data: - data = kwdata - yield from self.call_init(data) - yield - return (yield from self.call_do()) - - # Operations for Adapter - - def method_argument(self) -> TestGen[Optional[RecordIntDictRet]]: - return (yield from self.call_result()) - - def method_return(self, data: RecordValueDict = {}) -> TestGen[None]: - yield from self.set_inputs(data) - - def method_handle( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - enable = enable or (lambda: True) - yield from self.set_enable(enable()) - - # One extra Settle() required to propagate enable signal. - for _ in range(extra_settle_count + 1): - yield Settle() - while (arg := (yield from self.method_argument())) is None: - yield - yield from self.set_enable(enable()) - for _ in range(extra_settle_count + 1): - yield Settle() - - ret_out = mock_def_helper(self, function, arg) - yield from self.method_return(ret_out or {}) - yield - - def method_handle_loop( - self, - function: Callable[..., Optional[RecordIntDict]], - *, - enable: Optional[Callable[[], bool]] = None, - extra_settle_count: int = 0, - ) -> TestGen[None]: - yield Passive() - while True: - yield from self.method_handle(function, enable=enable, extra_settle_count=extra_settle_count) - - # Debug signals - - def debug_signals(self) -> SignalBundle: - return self.adapter.debug_signals() - - -def def_method_mock( - tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs -) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]: - """ - Decorator function to create method mock handlers. It should be applied on - a function which describes functionality which we want to invoke on method call. - Such function will be wrapped by `method_handle_loop` and called on each - method invocation. - - Function `f` should take only one argument `arg` - data used in function - invocation - and should return data to be sent as response to the method call. - - Function `f` can also be a method and take two arguments `self` and `arg`, - the data to be passed on to invoke a method. It should return data to be sent - as response to the method call. - - Instead of the `arg` argument, the data can be split into keyword arguments. - - Make sure to defer accessing state, since decorators are evaluated eagerly - during function declaration. - - Parameters - ---------- - tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] - Function to get the TestbenchIO providing appropriate `method_handle_loop`. - **kwargs - Arguments passed to `method_handle_loop`. - - Example - ------- - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k]) - def process(arg): - return {"data": arg["data"] + k} - return process - ``` - or equivalently - ``` - m = TestCircuit() - def target_process(k: int): - @def_method_mock(lambda: m.target[k], settle=1, enable=False) - def process(data): - return {"data": data + k} - return process - ``` - or for class methods - ``` - @def_method_mock(lambda self: self.target[k], settle=1, enable=False) - def process(self, data): - return {"data": data + k} - ``` - """ - - def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], TestGen[None]]: - @functools.wraps(func) - def mock(func_self=None, /) -> TestGen[None]: - f = func - getter: Any = tb_getter - kw = kwargs - if func_self is not None: - getter = getter.__get__(func_self) - f = f.__get__(func_self) - kw = {} - for k, v in kwargs.items(): - bind = getattr(v, "__get__", None) - kw[k] = bind(func_self) if bind else v - tb = getter() - assert isinstance(tb, TestbenchIO) - yield from tb.method_handle_loop(f, extra_settle_count=sched_prio, **kw) - - return mock - - return decorator diff --git a/test/common/__init__.py b/test/common/__init__.py new file mode 100644 index 000000000..20eb4e095 --- /dev/null +++ b/test/common/__init__.py @@ -0,0 +1,5 @@ +from .functions import * # noqa: F401 +from .infrastructure import * # noqa: F401 +from .sugar import * # noqa: F401 +from .testbenchio import * # noqa: F401 +from transactron._utils import data_layout # noqa: F401 diff --git a/test/common/functions.py b/test/common/functions.py new file mode 100644 index 000000000..c4ffc814a --- /dev/null +++ b/test/common/functions.py @@ -0,0 +1,29 @@ +from amaranth import * +from amaranth.hdl.ast import Statement +from amaranth.sim.core import Command +from typing import TypeVar, Any, Generator, TypeAlias +from coreblocks.utils._typing import RecordValueDict, RecordIntDict + +T = TypeVar("T") +TestGen: TypeAlias = Generator[Command | Value | Statement | None, Any, T] + + +def set_inputs(values: RecordValueDict, field: Record) -> TestGen[None]: + for name, value in values.items(): + if isinstance(value, dict): + yield from set_inputs(value, getattr(field, name)) + else: + yield getattr(field, name).eq(value) + + +def get_outputs(field: Record) -> TestGen[RecordIntDict]: + # return dict of all signal values in a record because amaranth's simulator can't read all + # values of a Record in a single yield - it can only read Values (Signals) + result = {} + for name, _, _ in field.layout: + val = getattr(field, name) + if isinstance(val, Signal): + result[name] = yield val + else: # field is a Record + result[name] = yield from get_outputs(val) + return result diff --git a/test/common/infrastructure.py b/test/common/infrastructure.py new file mode 100644 index 000000000..fe0b337d4 --- /dev/null +++ b/test/common/infrastructure.py @@ -0,0 +1,170 @@ +import os +import random +import unittest +import functools +from contextlib import contextmanager, nullcontext +from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast +from amaranth import * +from amaranth.sim import * +from .testbenchio import TestbenchIO +from ..gtkw_extension import write_vcd_ext +from transactron import Method +from transactron.lib import AdapterTrans +from transactron.core import TransactionModule +from coreblocks.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals + +T = TypeVar("T") +_T_nested_collection = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"] + + +def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]: + if isinstance(cont, (list, dict)): + if isinstance(cont, dict): + cont = cont.values() + return all([guard_nested_collection(elem, t) for elem in cont]) + elif isinstance(cont, t): + return True + else: + return False + + +_T_HasElaborate = TypeVar("_T_HasElaborate", bound=HasElaborate) + + +class SimpleTestCircuit(Elaboratable, Generic[_T_HasElaborate]): + def __init__(self, dut: _T_HasElaborate): + self._dut = dut + self._io: dict[str, _T_nested_collection[TestbenchIO]] = {} + + def __getattr__(self, name: str) -> Any: + return self._io[name] + + def elaborate(self, platform): + def transform_methods_to_testbenchios( + container: _T_nested_collection[Method], + ) -> tuple[_T_nested_collection["TestbenchIO"], Union[ModuleConnector, "TestbenchIO"]]: + if isinstance(container, list): + tb_list = [] + mc_list = [] + for elem in container: + tb, mc = transform_methods_to_testbenchios(elem) + tb_list.append(tb) + mc_list.append(mc) + return tb_list, ModuleConnector(*mc_list) + elif isinstance(container, dict): + tb_dict = {} + mc_dict = {} + for name, elem in container.items(): + tb, mc = transform_methods_to_testbenchios(elem) + tb_dict[name] = tb + mc_dict[name] = mc + return tb_dict, ModuleConnector(*mc_dict) + else: + tb = TestbenchIO(AdapterTrans(container)) + return tb, tb + + m = Module() + + m.submodules.dut = self._dut + + for name, attr in vars(self._dut).items(): + if guard_nested_collection(attr, Method) and attr: + tb_cont, mc = transform_methods_to_testbenchios(attr) + self._io[name] = tb_cont + m.submodules[name] = mc + + return m + + def debug_signals(self): + sigs = {"_dut": auto_debug_signals(self._dut)} + for name, io in self._io.items(): + sigs[name] = auto_debug_signals(io) + return sigs + + +class TestModule(Elaboratable): + def __init__(self, tested_module: HasElaborate, add_transaction_module): + self.tested_module = TransactionModule(tested_module) if add_transaction_module else tested_module + self.add_transaction_module = add_transaction_module + + def elaborate(self, platform) -> HasElaborate: + m = Module() + + # so that Amaranth allows us to use add_clock + _dummy = Signal() + m.d.sync += _dummy.eq(1) + + m.submodules.tested_module = self.tested_module + + return m + + +class PysimSimulator(Simulator): + def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True, traces_file=None): + test_module = TestModule(module, add_transaction_module) + tested_module = test_module.tested_module + super().__init__(test_module) + + clk_period = 1e-6 + self.add_clock(clk_period) + + if isinstance(tested_module, HasDebugSignals): + extra_signals = tested_module.debug_signals + else: + extra_signals = functools.partial(auto_debug_signals, tested_module) + + if traces_file: + traces_dir = "test/__traces__" + os.makedirs(traces_dir, exist_ok=True) + # Signal handling is hacky and accesses Simulator internals. + # TODO: try to merge with Amaranth. + if isinstance(extra_signals, Callable): + extra_signals = extra_signals() + clocks = [d.clk for d in cast(Any, self)._fragment.domains.values()] + + self.ctx = write_vcd_ext( + cast(Any, self)._engine, + f"{traces_dir}/{traces_file}.vcd", + f"{traces_dir}/{traces_file}.gtkw", + traces=[clocks, extra_signals], + ) + else: + self.ctx = nullcontext() + + self.deadline = clk_period * max_cycles + + def run(self) -> bool: + with self.ctx: + self.run_until(self.deadline) + + return not self.advance() + + +class TestCaseWithSimulator(unittest.TestCase): + @contextmanager + def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True): + traces_file = None + if "__COREBLOCKS_DUMP_TRACES" in os.environ: + traces_file = unittest.TestCase.id(self) + + sim = PysimSimulator( + module, max_cycles=max_cycles, add_transaction_module=add_transaction_module, traces_file=traces_file + ) + yield sim + res = sim.run() + + self.assertTrue(res, "Simulation time limit exceeded") + + def tick(self, cycle_cnt=1): + """ + Yields for the given number of cycles. + """ + + for _ in range(cycle_cnt): + yield + + def random_wait(self, max_cycle_cnt): + """ + Wait for a random amount of cycles in range [1, max_cycle_cnt) + """ + yield from self.tick(random.randrange(max_cycle_cnt)) diff --git a/test/common/sugar.py b/test/common/sugar.py new file mode 100644 index 000000000..beb4acf3a --- /dev/null +++ b/test/common/sugar.py @@ -0,0 +1,81 @@ +import functools +from typing import Callable, Any, Optional +from .testbenchio import TestbenchIO, TestGen +from coreblocks.utils._typing import RecordIntDict + + +def def_method_mock( + tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs +) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]: + """ + Decorator function to create method mock handlers. It should be applied on + a function which describes functionality which we want to invoke on method call. + Such function will be wrapped by `method_handle_loop` and called on each + method invocation. + + Function `f` should take only one argument `arg` - data used in function + invocation - and should return data to be sent as response to the method call. + + Function `f` can also be a method and take two arguments `self` and `arg`, + the data to be passed on to invoke a method. It should return data to be sent + as response to the method call. + + Instead of the `arg` argument, the data can be split into keyword arguments. + + Make sure to defer accessing state, since decorators are evaluated eagerly + during function declaration. + + Parameters + ---------- + tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO] + Function to get the TestbenchIO providing appropriate `method_handle_loop`. + **kwargs + Arguments passed to `method_handle_loop`. + + Example + ------- + ``` + m = TestCircuit() + def target_process(k: int): + @def_method_mock(lambda: m.target[k]) + def process(arg): + return {"data": arg["data"] + k} + return process + ``` + or equivalently + ``` + m = TestCircuit() + def target_process(k: int): + @def_method_mock(lambda: m.target[k], settle=1, enable=False) + def process(data): + return {"data": data + k} + return process + ``` + or for class methods + ``` + @def_method_mock(lambda self: self.target[k], settle=1, enable=False) + def process(self, data): + return {"data": data + k} + ``` + """ + + def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], TestGen[None]]: + @functools.wraps(func) + def mock(func_self=None, /) -> TestGen[None]: + f = func + getter: Any = tb_getter + kw = kwargs + if func_self is not None: + getter = getter.__get__(func_self) + f = f.__get__(func_self) + kw = {} + for k, v in kwargs.items(): + bind = getattr(v, "__get__", None) + kw[k] = bind(func_self) if bind else v + tb = getter() + assert isinstance(tb, TestbenchIO) + yield from tb.method_handle_loop(f, extra_settle_count=sched_prio, **kw) + + return mock + + return decorator diff --git a/test/common/testbenchio.py b/test/common/testbenchio.py new file mode 100644 index 000000000..9a2c956d8 --- /dev/null +++ b/test/common/testbenchio.py @@ -0,0 +1,132 @@ +from amaranth import * +from amaranth.sim import Settle, Passive +from typing import Optional, Callable +from transactron.lib import AdapterBase +from transactron.core import ValueLike, SignalBundle +from transactron._utils import mock_def_helper +from coreblocks.utils._typing import RecordIntDictRet, RecordValueDict, RecordIntDict +from .functions import set_inputs, get_outputs, TestGen + + +class TestbenchIO(Elaboratable): + def __init__(self, adapter: AdapterBase): + self.adapter = adapter + + def elaborate(self, platform): + m = Module() + m.submodules += self.adapter + return m + + # Low-level operations + + def set_enable(self, en) -> TestGen[None]: + yield self.adapter.en.eq(1 if en else 0) + + def enable(self) -> TestGen[None]: + yield from self.set_enable(True) + + def disable(self) -> TestGen[None]: + yield from self.set_enable(False) + + def done(self) -> TestGen[int]: + return (yield self.adapter.done) + + def wait_until_done(self) -> TestGen[None]: + while (yield self.adapter.done) != 1: + yield + + def set_inputs(self, data: RecordValueDict = {}) -> TestGen[None]: + yield from set_inputs(data, self.adapter.data_in) + + def get_outputs(self) -> TestGen[RecordIntDictRet]: + return (yield from get_outputs(self.adapter.data_out)) + + # Operations for AdapterTrans + + def call_init(self, data: RecordValueDict = {}, /, **kwdata: ValueLike | RecordValueDict) -> TestGen[None]: + if data and kwdata: + raise TypeError("call_init() takes either a single dict or keyword arguments") + if not data: + data = kwdata + yield from self.enable() + yield from self.set_inputs(data) + + def call_result(self) -> TestGen[Optional[RecordIntDictRet]]: + if (yield from self.done()): + return (yield from self.get_outputs()) + return None + + def call_do(self) -> TestGen[RecordIntDict]: + while (outputs := (yield from self.call_result())) is None: + yield + yield from self.disable() + return outputs + + def call_try( + self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict + ) -> TestGen[Optional[RecordIntDictRet]]: + if data and kwdata: + raise TypeError("call_try() takes either a single dict or keyword arguments") + if not data: + data = kwdata + yield from self.call_init(data) + yield + outputs = yield from self.call_result() + yield from self.disable() + return outputs + + def call(self, data: RecordIntDict = {}, /, **kwdata: int | RecordIntDict) -> TestGen[RecordIntDictRet]: + if data and kwdata: + raise TypeError("call() takes either a single dict or keyword arguments") + if not data: + data = kwdata + yield from self.call_init(data) + yield + return (yield from self.call_do()) + + # Operations for Adapter + + def method_argument(self) -> TestGen[Optional[RecordIntDictRet]]: + return (yield from self.call_result()) + + def method_return(self, data: RecordValueDict = {}) -> TestGen[None]: + yield from self.set_inputs(data) + + def method_handle( + self, + function: Callable[..., Optional[RecordIntDict]], + *, + enable: Optional[Callable[[], bool]] = None, + extra_settle_count: int = 0, + ) -> TestGen[None]: + enable = enable or (lambda: True) + yield from self.set_enable(enable()) + + # One extra Settle() required to propagate enable signal. + for _ in range(extra_settle_count + 1): + yield Settle() + while (arg := (yield from self.method_argument())) is None: + yield + yield from self.set_enable(enable()) + for _ in range(extra_settle_count + 1): + yield Settle() + + ret_out = mock_def_helper(self, function, arg) + yield from self.method_return(ret_out or {}) + yield + + def method_handle_loop( + self, + function: Callable[..., Optional[RecordIntDict]], + *, + enable: Optional[Callable[[], bool]] = None, + extra_settle_count: int = 0, + ) -> TestGen[None]: + yield Passive() + while True: + yield from self.method_handle(function, enable=enable, extra_settle_count=extra_settle_count) + + # Debug signals + + def debug_signals(self) -> SignalBundle: + return self.adapter.debug_signals() diff --git a/test/fu/test_alu.py b/test/fu/test_alu.py index 4e1cbcb7c..0aad519a0 100644 --- a/test/fu/test_alu.py +++ b/test/fu/test_alu.py @@ -3,7 +3,7 @@ from test.fu.functional_common import ExecFn, FunctionalUnitTestCase -from test.common import signed_to_int +from transactron._utils import signed_to_int class AluUnitTest(FunctionalUnitTestCase[AluFn.Fn]): diff --git a/test/fu/test_div_unit.py b/test/fu/test_div_unit.py index 158808369..baef39d17 100644 --- a/test/fu/test_div_unit.py +++ b/test/fu/test_div_unit.py @@ -5,7 +5,7 @@ from test.fu.functional_common import ExecFn, FunctionalUnitTestCase -from test.common import signed_to_int, int_to_signed +from transactron._utils import signed_to_int, int_to_signed @parameterized_class( diff --git a/test/fu/test_jb_unit.py b/test/fu/test_jb_unit.py index 5509e1306..a37a4cb0c 100644 --- a/test/fu/test_jb_unit.py +++ b/test/fu/test_jb_unit.py @@ -7,7 +7,7 @@ from coreblocks.params.layouts import FuncUnitLayouts, FetchLayouts from coreblocks.utils.protocols import FuncUnit -from test.common import signed_to_int +from transactron._utils import signed_to_int from test.fu.functional_common import ExecFn, FunctionalUnitTestCase diff --git a/test/fu/test_mul_unit.py b/test/fu/test_mul_unit.py index 5550aaa15..e97c89912 100644 --- a/test/fu/test_mul_unit.py +++ b/test/fu/test_mul_unit.py @@ -3,7 +3,7 @@ from coreblocks.params import * from coreblocks.fu.mul_unit import MulFn, MulComponent, MulType -from test.common import signed_to_int, int_to_signed +from transactron._utils import signed_to_int, int_to_signed from test.fu.functional_common import ExecFn, FunctionalUnitTestCase diff --git a/test/lsu/test_dummylsu.py b/test/lsu/test_dummylsu.py index d269049a3..bfa29d532 100644 --- a/test/lsu/test_dummylsu.py +++ b/test/lsu/test_dummylsu.py @@ -5,6 +5,7 @@ from amaranth.sim import Settle, Passive from transactron.lib import Adapter +from transactron._utils import int_to_signed, signed_to_int from coreblocks.params import OpType, GenParams from coreblocks.lsu.dummyLsu import LSUDummy from coreblocks.params.configurations import test_core_config @@ -13,7 +14,7 @@ from coreblocks.params.dependencies import DependencyManager from coreblocks.params.layouts import ExceptionRegisterLayouts from coreblocks.peripherals.wishbone import * -from test.common import TestbenchIO, TestCaseWithSimulator, def_method_mock, int_to_signed, signed_to_int +from test.common import TestbenchIO, TestCaseWithSimulator, def_method_mock from test.peripherals.test_wishbone import WishboneInterfaceWrapper diff --git a/transactron/_utils.py b/transactron/_utils.py index 138c7222b..bcd60d8db 100644 --- a/transactron/_utils.py +++ b/transactron/_utils.py @@ -4,7 +4,7 @@ from typing import Any, Concatenate, Optional, TypeAlias, TypeGuard, TypeVar from collections.abc import Callable, Iterable, Mapping from amaranth import * -from coreblocks.utils._typing import LayoutLike +from coreblocks.utils._typing import LayoutLike, ShapeLike from coreblocks.utils import OneHotSwitchDynamic __all__ = [ @@ -164,3 +164,64 @@ def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elabo return None, default else: raise RuntimeError("Not called from a method") + + +def data_layout(val: ShapeLike) -> LayoutLike: + return [("data", val)] + + +def neg(x: int, xlen: int) -> int: + """ + Computes the negation of a number in the U2 system. + + Parameters + ---------- + x: int + Number in U2 system. + xlen : int + Bit width of x. + + Returns + ------- + return : int + Negation of x in the U2 system. + """ + return (-x) & (2**xlen - 1) + + +def int_to_signed(x: int, xlen: int) -> int: + """ + Converts a Python integer into its U2 representation. + + Parameters + ---------- + x: int + Signed Python integer. + xlen : int + Bit width of x. + + Returns + ------- + return : int + Representation of x in the U2 system. + """ + return x & (2**xlen - 1) + + +def signed_to_int(x: int, xlen: int) -> int: + """ + Changes U2 representation into Python integer + + Parameters + ---------- + x: int + Number in U2 system. + xlen : int + Bit width of x. + + Returns + ------- + return : int + Representation of x as signed Python integer. + """ + return x | -(x & (2 ** (xlen - 1)))