Skip to content

Commit

Permalink
Allow using AdapterBase in MethodMock (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotro888 authored Nov 26, 2024
1 parent 4723228 commit 82aceee
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
47 changes: 47 additions & 0 deletions test/testing/test_method_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import random
from amaranth import *
from amaranth.sim import *

from transactron import *
from transactron.testing import TestCaseWithSimulator, TestbenchContext
from transactron.testing.infrastructure import SimpleTestCircuit
from transactron.testing.method_mock import MethodMock, def_method_mock
from transactron.lib import *


class ReverseMethodMockTestCircuit(Elaboratable):
def __init__(self, width):
self.method = Method(i=from_method_layout([("input", width)]), o=from_method_layout([("output", width)]))

def elaborate(self, platform):
m = TModule()

@def_method(m, self.method)
def _(input):
return input + 1

return m


class TestReverseMethodMock(TestCaseWithSimulator):
async def active(self, sim: TestbenchContext):
for _ in range(10):
await sim.tick()

@def_method_mock(lambda self: self.m.method, enable=lambda _: random.randint(0, 1))
def method_mock(self, output: int):
input = random.randrange(0, 2**self.width)

@MethodMock.effect
def _():
assert output == (input + 1) % 2**self.width

return {"input": input}

def test_reverse_method_mock(self):
random.seed(42)
self.width = 4
self.m = SimpleTestCircuit(ReverseMethodMockTestCircuit(self.width))
self.accepted_val = 0
with self.run_simulation(self.m) as sim:
sim.add_testbench(self.active)
7 changes: 4 additions & 3 deletions transactron/testing/method_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, Any, Optional

from amaranth.sim._async import SimulatorContext
from transactron.lib.adapters import Adapter
from transactron.lib.adapters import Adapter, AdapterBase
from transactron.utils.transactron_helpers import async_mock_def_helper
from .testbenchio import TestbenchIO
from transactron.utils._typing import RecordIntDict
Expand All @@ -15,7 +15,7 @@
class MethodMock:
def __init__(
self,
adapter: Adapter,
adapter: AdapterBase,
function: Callable[..., Optional[RecordIntDict]],
*,
validate_arguments: Optional[Callable[..., bool]] = None,
Expand Down Expand Up @@ -63,6 +63,8 @@ async def output_process(

async def validate_arguments_process(self, sim: SimulatorContext) -> None:
assert self.validate_arguments is not None
assert isinstance(self.adapter, Adapter)

sync = sim._design.lookup_domain("sync", None) # type: ignore
async for *args, clk, _ in (
sim.changed(*(a for a, _ in self.adapter.validators)).edge(sync.clk, 1).edge(self.adapter.en, 1)
Expand Down Expand Up @@ -166,7 +168,6 @@ def mock(func_self=None, /) -> MethodMock:
kw[k] = bind(func_self) if bind else v
tb = getter()
assert isinstance(tb, TestbenchIO)
assert isinstance(tb.adapter, Adapter)
return MethodMock(tb.adapter, f, **kw)

mock._transactron_method_mock = 1 # type: ignore
Expand Down

0 comments on commit 82aceee

Please sign in to comment.