Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autumn cleaning part 4 - split test/common.py #477

Merged
merged 10 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
461 changes: 0 additions & 461 deletions test/common.py

This file was deleted.

5 changes: 5 additions & 0 deletions test/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .functions import * # noqa: F401
from .infrastructure import * # noqa: F401
from .sugar import * # noqa: F401
from .testbenchio import * # noqa: F401
from transactron._utils import data_layout # noqa: F401
28 changes: 28 additions & 0 deletions test/common/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from amaranth import *
from typing import TYPE_CHECKING
from transactron._utils import RecordValueDict, RecordIntDict


if TYPE_CHECKING:
from .testbenchio import TestGen
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved


def set_inputs(values: RecordValueDict, field: Record) -> "TestGen[None]":
tilk marked this conversation as resolved.
Show resolved Hide resolved
for name, value in values.items():
if isinstance(value, dict):
yield from set_inputs(value, getattr(field, name))
else:
yield getattr(field, name).eq(value)


def get_outputs(field: Record) -> "TestGen[RecordIntDict]":
# return dict of all signal values in a record because amaranth's simulator can't read all
# values of a Record in a single yield - it can only read Values (Signals)
result = {}
for name, _, _ in field.layout:
val = getattr(field, name)
if isinstance(val, Signal):
result[name] = yield val
else: # field is a Record
result[name] = yield from get_outputs(val)
return result
170 changes: 170 additions & 0 deletions test/common/infrastructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import os
import random
import unittest
import functools
from contextlib import contextmanager, nullcontext
from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast
from amaranth import *
from amaranth.sim import *
from .testbenchio import TestbenchIO
from ..gtkw_extension import write_vcd_ext
from transactron import Method
from transactron.lib import AdapterTrans
from transactron.core import TransactionModule
from coreblocks.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals

T = TypeVar("T")
_T_nested_collection = T | list["_T_nested_collection[T]"] | dict[str, "_T_nested_collection[T]"]


def guard_nested_collection(cont: Any, t: Type[T]) -> TypeGuard[_T_nested_collection[T]]:
if isinstance(cont, (list, dict)):
if isinstance(cont, dict):
cont = cont.values()
return all([guard_nested_collection(elem, t) for elem in cont])
elif isinstance(cont, t):
return True
else:
return False


_T_HasElaborate = TypeVar("_T_HasElaborate", bound=HasElaborate)


class SimpleTestCircuit(Elaboratable, Generic[_T_HasElaborate]):
def __init__(self, dut: _T_HasElaborate):
self._dut = dut
self._io: dict[str, _T_nested_collection[TestbenchIO]] = {}

def __getattr__(self, name: str) -> Any:
return self._io[name]

def elaborate(self, platform):
def transform_methods_to_testbenchios(
container: _T_nested_collection[Method],
) -> tuple[_T_nested_collection["TestbenchIO"], Union[ModuleConnector, "TestbenchIO"]]:
if isinstance(container, list):
tb_list = []
mc_list = []
for elem in container:
tb, mc = transform_methods_to_testbenchios(elem)
tb_list.append(tb)
mc_list.append(mc)
return tb_list, ModuleConnector(*mc_list)
elif isinstance(container, dict):
tb_dict = {}
mc_dict = {}
for name, elem in container.items():
tb, mc = transform_methods_to_testbenchios(elem)
tb_dict[name] = tb
mc_dict[name] = mc
return tb_dict, ModuleConnector(*mc_dict)
else:
tb = TestbenchIO(AdapterTrans(container))
return tb, tb

m = Module()

m.submodules.dut = self._dut

for name, attr in vars(self._dut).items():
if guard_nested_collection(attr, Method) and attr:
tb_cont, mc = transform_methods_to_testbenchios(attr)
self._io[name] = tb_cont
m.submodules[name] = mc

return m

def debug_signals(self):
sigs = {"_dut": auto_debug_signals(self._dut)}
for name, io in self._io.items():
sigs[name] = auto_debug_signals(io)
return sigs


class TestModule(Elaboratable):
def __init__(self, tested_module: HasElaborate, add_transaction_module):
self.tested_module = TransactionModule(tested_module) if add_transaction_module else tested_module
self.add_transaction_module = add_transaction_module

def elaborate(self, platform) -> HasElaborate:
m = Module()

# so that Amaranth allows us to use add_clock
_dummy = Signal()
m.d.sync += _dummy.eq(1)

m.submodules.tested_module = self.tested_module

return m


class PysimSimulator(Simulator):
def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True, traces_file=None):
test_module = TestModule(module, add_transaction_module)
tested_module = test_module.tested_module
super().__init__(test_module)

clk_period = 1e-6
self.add_clock(clk_period)

if isinstance(tested_module, HasDebugSignals):
extra_signals = tested_module.debug_signals
else:
extra_signals = functools.partial(auto_debug_signals, tested_module)

if traces_file:
traces_dir = "test/__traces__"
os.makedirs(traces_dir, exist_ok=True)
# Signal handling is hacky and accesses Simulator internals.
# TODO: try to merge with Amaranth.
if isinstance(extra_signals, Callable):
extra_signals = extra_signals()
clocks = [d.clk for d in cast(Any, self)._fragment.domains.values()]

self.ctx = write_vcd_ext(
cast(Any, self)._engine,
f"{traces_dir}/{traces_file}.vcd",
f"{traces_dir}/{traces_file}.gtkw",
traces=[clocks, extra_signals],
)
else:
self.ctx = nullcontext()

self.deadline = clk_period * max_cycles

def run(self) -> bool:
with self.ctx:
self.run_until(self.deadline)

return not self.advance()


class TestCaseWithSimulator(unittest.TestCase):
@contextmanager
def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True):
traces_file = None
if "__COREBLOCKS_DUMP_TRACES" in os.environ:
traces_file = unittest.TestCase.id(self)

sim = PysimSimulator(
module, max_cycles=max_cycles, add_transaction_module=add_transaction_module, traces_file=traces_file
)
yield sim
res = sim.run()

self.assertTrue(res, "Simulation time limit exceeded")

def tick(self, cycle_cnt=1):
"""
Yields for the given number of cycles.
"""

for _ in range(cycle_cnt):
yield

def random_wait(self, max_cycle_cnt):
"""
Wait for a random amount of cycles in range [1, max_cycle_cnt)
"""
yield from self.tick(random.randrange(max_cycle_cnt))
81 changes: 81 additions & 0 deletions test/common/sugar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import functools
from typing import Callable, Any, Optional
from .testbenchio import TestbenchIO, TestGen
from transactron._utils import RecordIntDict


def def_method_mock(
tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs
) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]:
"""
Decorator function to create method mock handlers. It should be applied on
a function which describes functionality which we want to invoke on method call.
Such function will be wrapped by `method_handle_loop` and called on each
method invocation.

Function `f` should take only one argument `arg` - data used in function
invocation - and should return data to be sent as response to the method call.

Function `f` can also be a method and take two arguments `self` and `arg`,
the data to be passed on to invoke a method. It should return data to be sent
as response to the method call.

Instead of the `arg` argument, the data can be split into keyword arguments.

Make sure to defer accessing state, since decorators are evaluated eagerly
during function declaration.

Parameters
----------
tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO]
Function to get the TestbenchIO providing appropriate `method_handle_loop`.
**kwargs
Arguments passed to `method_handle_loop`.

Example
-------
```
m = TestCircuit()
def target_process(k: int):
@def_method_mock(lambda: m.target[k])
def process(arg):
return {"data": arg["data"] + k}
return process
```
or equivalently
```
m = TestCircuit()
def target_process(k: int):
@def_method_mock(lambda: m.target[k], settle=1, enable=False)
def process(data):
return {"data": data + k}
return process
```
or for class methods
```
@def_method_mock(lambda self: self.target[k], settle=1, enable=False)
def process(self, data):
return {"data": data + k}
```
"""

def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[], TestGen[None]]:
@functools.wraps(func)
def mock(func_self=None, /) -> TestGen[None]:
f = func
getter: Any = tb_getter
kw = kwargs
if func_self is not None:
getter = getter.__get__(func_self)
f = f.__get__(func_self)
kw = {}
for k, v in kwargs.items():
bind = getattr(v, "__get__", None)
kw[k] = bind(func_self) if bind else v
tb = getter()
assert isinstance(tb, TestbenchIO)
yield from tb.method_handle_loop(f, extra_settle_count=sched_prio, **kw)

return mock

return decorator
Loading