Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow using AdapterBase in MethodMock #16

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions test/testing/test_method_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import random
from amaranth import *
from amaranth.sim import *

from transactron import *
from transactron.testing import TestCaseWithSimulator, TestbenchContext
from transactron.testing.infrastructure import SimpleTestCircuit
from transactron.testing.method_mock import MethodMock, def_method_mock
from transactron.lib import *


class ReverseMethodMockTestCircuit(Elaboratable):
def __init__(self, width):
self.method = Method(i=from_method_layout([("input", width)]), o=from_method_layout([("output", width)]))

def elaborate(self, platform):
m = TModule()

@def_method(m, self.method)
def _(input):
return input + 1

return m


class TestReverseMethodMock(TestCaseWithSimulator):
async def active(self, sim: TestbenchContext):
for _ in range(10):
await sim.tick()

@def_method_mock(lambda self: self.m.method, enable=lambda _: random.randint(0, 1))
def method_mock(self, output: int):
input = random.randrange(0, 2**self.width)

@MethodMock.effect
def _():
assert output == (input + 1) % 2**self.width

return {"input": input}

def test_reverse_method_mock(self):
random.seed(42)
self.width = 4
self.m = SimpleTestCircuit(ReverseMethodMockTestCircuit(self.width))
self.accepted_val = 0
with self.run_simulation(self.m) as sim:
sim.add_testbench(self.active)
7 changes: 4 additions & 3 deletions transactron/testing/method_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, Any, Optional

from amaranth.sim._async import SimulatorContext
from transactron.lib.adapters import Adapter
from transactron.lib.adapters import Adapter, AdapterBase
from transactron.utils.transactron_helpers import async_mock_def_helper
from .testbenchio import TestbenchIO
from transactron.utils._typing import RecordIntDict
Expand All @@ -15,7 +15,7 @@
class MethodMock:
def __init__(
self,
adapter: Adapter,
adapter: AdapterBase,
function: Callable[..., Optional[RecordIntDict]],
*,
validate_arguments: Optional[Callable[..., bool]] = None,
Expand Down Expand Up @@ -63,6 +63,8 @@ async def output_process(

async def validate_arguments_process(self, sim: SimulatorContext) -> None:
assert self.validate_arguments is not None
assert isinstance(self.adapter, Adapter)

sync = sim._design.lookup_domain("sync", None) # type: ignore
async for *args, clk, _ in (
sim.changed(*(a for a, _ in self.adapter.validators)).edge(sync.clk, 1).edge(self.adapter.en, 1)
Expand Down Expand Up @@ -166,7 +168,6 @@ def mock(func_self=None, /) -> MethodMock:
kw[k] = bind(func_self) if bind else v
tb = getter()
assert isinstance(tb, TestbenchIO)
assert isinstance(tb.adapter, Adapter)
return MethodMock(tb.adapter, f, **kw)

mock._transactron_method_mock = 1 # type: ignore
Expand Down