Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved SimpleTestCircuit #33

Draft
wants to merge 2 commits into
base: tilk/redesign
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions test/testing/test_method_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,44 @@
from transactron.lib import *


class SimpleMethodMockTestCircuit(Elaboratable):
method: Required[Method]
wrapper: Provided[Method]

def __init__(self, width: int):
self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width}))
self.wrapper = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width}))

def elaborate(self, platform):
m = TModule()

@def_method(m, self.wrapper)
def _(input):
return {"output": self.method(m, input).output + 1}

return m


class TestMethodMock(TestCaseWithSimulator):
async def process(self, sim: TestbenchContext):
for _ in range(20):
val = random.randrange(2**self.width)
ret = await self.dut.wrapper.call(sim, input=val)
assert ret.output == (val + 2) % 2**self.width

@def_method_mock(lambda self: self.dut.method, enable=lambda _: random.randint(0, 1))
def method_mock(self, input):
return {"output": input + 1}

def test_method_mock_simple(self):
random.seed(42)
self.width = 4
self.dut = SimpleTestCircuit(SimpleMethodMockTestCircuit(self.width))

with self.run_simulation(self.dut) as sim:
sim.add_testbench(self.process)


class ReverseMethodMockTestCircuit(Elaboratable):
def __init__(self, width):
self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width}))
Expand Down
15 changes: 13 additions & 2 deletions transactron/core/method.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import Sequence
import enum

from transactron.utils import *
from amaranth import *
from amaranth import tracer
from typing import TYPE_CHECKING, Optional, Iterator, Unpack
from typing import TYPE_CHECKING, Annotated, Optional, Iterator, TypeAlias, TypeVar, Unpack
from .transaction_base import *
from contextlib import contextmanager
from transactron.utils.assign import AssignArg
Expand All @@ -19,7 +20,17 @@
from .transaction import Transaction # noqa: F401


__all__ = ["Method", "Methods"]
__all__ = ["MethodDir", "Provided", "Required", "Method", "Methods"]


class MethodDir(enum.Enum):
PROVIDED = enum.auto()
REQUIRED = enum.auto()


_T = TypeVar("_T")
Provided: TypeAlias = Annotated[_T, MethodDir.PROVIDED]
Required: TypeAlias = Annotated[_T, MethodDir.REQUIRED]


class Method(TransactionBase["Transaction | Method"]):
Expand Down
4 changes: 4 additions & 0 deletions transactron/lib/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def create(
method = Method(name=name, i=i, o=o, src_loc=get_src_loc(src_loc))
return Adapter(method, **kwargs)

def update_args(self, **kwargs: Unpack[AdapterBodyParams]):
self.kwargs.update(kwargs)
return self

def set(self, with_validate_arguments: Optional[bool]):
if with_validate_arguments is not None:
self.with_validate_arguments = with_validate_arguments
Expand Down
18 changes: 13 additions & 5 deletions transactron/testing/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from amaranth import *
from amaranth.sim import *
from amaranth.sim._async import SimulatorContext
from transactron.core.method import MethodDir
from transactron.lib.adapters import Adapter

from transactron.utils.dependencies import DependencyContext, DependencyManager
from .testbenchio import TestbenchIO
Expand Down Expand Up @@ -58,6 +60,7 @@ def __getattr__(self, name: str) -> Any:

def elaborate(self, platform):
def transform_methods_to_testbenchios(
adapter_type: type[Adapter] | type[AdapterTrans],
container: _T_nested_collection[Method | Methods],
) -> tuple[
_T_nested_collection["TestbenchIO"],
Expand All @@ -67,32 +70,37 @@ def transform_methods_to_testbenchios(
tb_list = []
mc_list = []
for elem in container:
tb, mc = transform_methods_to_testbenchios(elem)
tb, mc = transform_methods_to_testbenchios(adapter_type, elem)
tb_list.append(tb)
mc_list.append(mc)
return tb_list, ModuleConnector(*mc_list)
elif isinstance(container, dict):
tb_dict = {}
mc_dict = {}
for name, elem in container.items():
tb, mc = transform_methods_to_testbenchios(elem)
tb, mc = transform_methods_to_testbenchios(adapter_type, elem)
tb_dict[name] = tb
mc_dict[name] = mc
return tb_dict, ModuleConnector(*mc_dict)
elif isinstance(container, Methods):
tb_list = [TestbenchIO(AdapterTrans(method)) for method in container]
tb_list = [TestbenchIO(adapter_type(method)) for method in container]
return list(tb_list), ModuleConnector(*tb_list)
else:
tb = TestbenchIO(AdapterTrans(container))
tb = TestbenchIO(adapter_type(container))
return tb, tb

m = Module()

m.submodules.dut = self._dut
hints = self._dut.__class__.__annotations__

for name, attr in vars(self._dut).items():
if guard_nested_collection(attr, Method | Methods) and attr:
tb_cont, mc = transform_methods_to_testbenchios(attr)
if name in hints and MethodDir.REQUIRED in hints[name].__metadata__:
adapter_type = Adapter
else: # PROVIDED is the default
adapter_type = AdapterTrans
tb_cont, mc = transform_methods_to_testbenchios(adapter_type, attr)
self._io[name] = tb_cont
m.submodules[name] = mc

Expand Down
9 changes: 8 additions & 1 deletion transactron/testing/method_mock.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from contextlib import contextmanager
import functools
from typing import Callable, Any, Optional
from typing import Callable, Any, Optional, Unpack

from amaranth.sim._async import SimulatorContext
from transactron.core.body import AdapterBodyParams
from transactron.lib.adapters import Adapter, AdapterBase
from transactron.utils.transactron_helpers import async_mock_def_helper
from .testbenchio import TestbenchIO
Expand All @@ -21,7 +22,13 @@ def __init__(
validate_arguments: Optional[Callable[..., bool]] = None,
enable: Callable[[], bool] = lambda: True,
delay: float = 0,
**kwargs: Unpack[AdapterBodyParams],
):
if isinstance(adapter, Adapter):
adapter.set(with_validate_arguments=validate_arguments is not None).update_args(**kwargs)
else:
assert validate_arguments is None
assert kwargs == {}
self.adapter = adapter
self.function = function
self.validate_arguments = validate_arguments
Expand Down