Skip to content

Commit

Permalink
Merge branch 'master' into component-typing
Browse files Browse the repository at this point in the history
  • Loading branch information
piotro888 authored Jan 18, 2025
2 parents 294016e + f15976e commit a83193b
Show file tree
Hide file tree
Showing 37 changed files with 753 additions and 469 deletions.
6 changes: 3 additions & 3 deletions docs/transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ Suppose we have the following layout, which is an input layout for a method call

```python
layout = [("foo", 1), ("bar", 32)]
method = Method(input_layout=layout)
method = Method(i=layout)
```

The method can be called in multiple ways.
Expand Down Expand Up @@ -170,7 +170,7 @@ Take the following definitions:

```python
layout2 = [("foobar", layout), ("baz", 42)]
method2 = Method(input_layout=layout2)
method2 = Method(i=layout2)
```

One can then pass the arguments using `dict`s in following ways:
Expand Down Expand Up @@ -208,7 +208,7 @@ The `dict` syntax can be used for returning values from methods.
Take the following method declaration:

```python
method3 = Method(input_layout=layout, output_layout=layout2)
method3 = Method(i=layout, o=layout2)
```

One can then define this method as follows:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "transactron"
dynamic = ["version"]
dependencies = [
"amaranth == 0.5.3",
"amaranth-stubs @ git+https://github.com/piotro888/amaranth-stubs.git@e25a7fa11d4a0d66ed18190f31b60914f222e74c"
"amaranth-stubs @ git+https://github.com/piotro888/amaranth-stubs.git@e25a7fa11d4a0d66ed18190f31b60914f222e74c",
"dataclasses-json == 0.6.3",
"tabulate == 0.9.0"
]
Expand Down
42 changes: 25 additions & 17 deletions test/core/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from collections import deque
from typing import Iterable, Callable
from transactron.core.keys import TransactionManagerKey

from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout

Expand All @@ -20,33 +21,36 @@
from transactron.core import Priority
from transactron.core.schedulers import trivial_roundrobin_cc_scheduler, eager_deterministic_cc_scheduler
from transactron.core.manager import TransactionScheduler
from transactron.utils.dependencies import DependencyContext
from transactron.utils.dependencies import DependencyContext, DependencyManager


class TestNames(TestCase):
def test_names(self):
mgr = TransactionManager()
mgr._MustUse__silence = True # type: ignore

class T(Elaboratable):
def __init__(self):
self._MustUse__silence = True # type: ignore
Transaction(manager=mgr)
with DependencyContext(DependencyManager()) as ctx:
ctx.manager.add_dependency(TransactionManagerKey(), mgr)

T()
assert mgr.transactions[0].name == "T"
class T(Elaboratable):
def __init__(self):
self._MustUse__silence = True # type: ignore
Transaction()

t = Transaction(name="x", manager=mgr)
assert t.name == "x"
T()
assert mgr.transactions[0].name == "T"

t = Transaction(manager=mgr)
assert t.name == "t"
t = Transaction(name="x")
assert t.name == "x"

m = Method(name="x")
assert m.name == "x"
t = Transaction()
assert t.name == "t"

m = Method()
assert m.name == "m"
m = Method(name="x")
assert m.name == "x"

m = Method()
assert m.name == "m"


class TestScheduler(TestCaseWithSimulator):
Expand Down Expand Up @@ -111,7 +115,7 @@ def __init__(self, scheduler):
def elaborate(self, platform):
m = TModule()
tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler))
adapter = Adapter(i=data_layout(32), o=data_layout(32))
adapter = Adapter.create(i=data_layout(32), o=data_layout(32))
m.submodules.out = self.out = TestbenchIO(adapter)
m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface))
m.submodules.in2 = self.in2 = TestbenchIO(AdapterTrans(adapter.iface))
Expand Down Expand Up @@ -428,7 +432,11 @@ class SingleCallerTestCircuit(Elaboratable):
def elaborate(self, platform):
m = TModule()

method = Method(single_caller=True)
method = Method()

@def_method(m, method, single_caller=True)
def _():
pass

with Transaction().body(m):
method(m)
Expand Down
4 changes: 2 additions & 2 deletions test/lib/test_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def elaborate(self, platform):

get_results = []
for i in range(self.count):
input = TestbenchIO(Adapter(o=self.lay))
input = TestbenchIO(Adapter.create(o=self.lay))
get_results.append(input.adapter.iface)
m.submodules[f"input_{i}"] = input
self.inputs.append(input)

# Create ManyToOneConnectTrans, which will serialize results from different inputs
output = TestbenchIO(Adapter(i=self.lay))
output = TestbenchIO(Adapter.create(i=self.lay))
m.submodules.output = self.output = output
m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface)

Expand Down
4 changes: 2 additions & 2 deletions test/lib/test_reqres.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def setup_method(self):

layout = [("field", self.data_width)]

self.req_method = TestbenchIO(Adapter(i=layout))
self.resp_method = TestbenchIO(Adapter(o=layout))
self.req_method = TestbenchIO(Adapter.create(i=layout))
self.resp_method = TestbenchIO(Adapter.create(o=layout))

self.test_circuit = SimpleTestCircuit(
Serializer(
Expand Down
6 changes: 3 additions & 3 deletions test/lib/test_simultaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
class ConditionTestCircuit(Elaboratable):
def __init__(self, target: Method, *, nonblocking: bool, priority: bool, catchall: bool):
self.target = target
self.source = Method(i=[("cond1", 1), ("cond2", 1), ("cond3", 1)], single_caller=True)
self.source = Method(i=[("cond1", 1), ("cond2", 1), ("cond3", 1)])
self.nonblocking = nonblocking
self.priority = priority
self.catchall = catchall

def elaborate(self, platform):
m = TModule()

@def_method(m, self.source)
@def_method(m, self.source, single_caller=True)
def _(cond1, cond2, cond3):
with condition(m, nonblocking=self.nonblocking, priority=self.priority) as branch:
with branch(cond1):
Expand All @@ -49,7 +49,7 @@ class TestCondition(TestCaseWithSimulator):
@pytest.mark.parametrize("priority", [False, True])
@pytest.mark.parametrize("catchall", [False, True])
def test_condition(self, nonblocking: bool, priority: bool, catchall: bool):
target = TestbenchIO(Adapter(i=[("cond", 2)]))
target = TestbenchIO(Adapter.create(i=[("cond", 2)]))

circ = SimpleTestCircuit(
ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall),
Expand Down
50 changes: 38 additions & 12 deletions test/lib/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections.abc import Callable
from amaranth_types import ShapeLike
import pytest
import random
from collections import deque
from datetime import timedelta
from hypothesis import given, settings, Phase
from transactron.testing import *
from transactron.lib.storage import *
from transactron.utils.transactron_helpers import make_layout


class TestContentAddressableMemory(TestCaseWithSimulator):
Expand Down Expand Up @@ -134,13 +137,20 @@ def test_random(self, in_push, in_write, in_read, in_remove):
sim.add_testbench(self.remove_process(in_remove))


bank_shapes = [
(6, lambda x: x, lambda x: x),
(make_layout(("data_field", 6)), lambda x: {"data_field": x}, lambda x: x["data_field"]),
]


class TestMemoryBank(TestCaseWithSimulator):
test_conf = [(9, 3, 3, 3, 14), (16, 1, 1, 3, 15), (16, 1, 1, 1, 16), (12, 3, 1, 1, 17), (9, 0, 0, 0, 18)]

@pytest.mark.parametrize("max_addr, writer_rand, reader_req_rand, reader_resp_rand, seed", test_conf)
@pytest.mark.parametrize("transparent", [False, True])
@pytest.mark.parametrize("read_ports", [1, 2])
@pytest.mark.parametrize("write_ports", [1, 2])
@pytest.mark.parametrize("shape,to_shape,from_shape", bank_shapes)
def test_mem(
self,
max_addr: int,
Expand All @@ -151,14 +161,16 @@ def test_mem(
transparent: bool,
read_ports: int,
write_ports: int,
shape: ShapeLike,
to_shape: Callable,
from_shape: Callable,
):
test_count = 200

data_width = 6
m = SimpleTestCircuit(
MemoryBank(
data_layout=[("data", data_width)],
elem_count=max_addr,
shape=shape,
depth=max_addr,
transparent=transparent,
read_ports=read_ports,
write_ports=write_ports,
Expand All @@ -173,9 +185,9 @@ def test_mem(
def writer(i):
async def process(sim: TestbenchContext):
for cycle in range(test_count):
d = random.randrange(2**data_width)
d = random.randrange(2 ** Shape.cast(shape).width)
a = random.randrange(max_addr)
await m.write[i].call(sim, data={"data": d}, addr=a)
await m.write[i].call(sim, data=to_shape(d), addr=a)
await sim.delay(1e-9 * (i + 2 if not transparent else i))
data[a] = d
await self.random_wait(sim, writer_rand)
Expand All @@ -202,7 +214,7 @@ async def process(sim: TestbenchContext):
await self.random_wait(sim, reader_resp_rand or 1, min_cycle_cnt=1)
await sim.delay(1e-9 * (write_ports + 3))
d = read_req_queues[i].popleft()
assert (await m.read_resp[i].call(sim)).data == d
assert from_shape((await m.read_resp[i].call(sim)).data) == d
await self.random_wait(sim, reader_resp_rand)

return process
Expand All @@ -224,13 +236,27 @@ class TestAsyncMemoryBank(TestCaseWithSimulator):
)
@pytest.mark.parametrize("read_ports", [1, 2])
@pytest.mark.parametrize("write_ports", [1, 2])
def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int, read_ports: int, write_ports: int):
@pytest.mark.parametrize("shape,to_shape,from_shape", bank_shapes)
def test_mem(
self,
max_addr: int,
writer_rand: int,
reader_rand: int,
seed: int,
read_ports: int,
write_ports: int,
shape: ShapeLike,
to_shape: Callable,
from_shape: Callable,
):
test_count = 200

data_width = 6
m = SimpleTestCircuit(
AsyncMemoryBank(
data_layout=[("data", data_width)], elem_count=max_addr, read_ports=read_ports, write_ports=write_ports
shape=shape,
depth=max_addr,
read_ports=read_ports,
write_ports=write_ports,
),
)

Expand All @@ -241,9 +267,9 @@ def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int,
def writer(i):
async def process(sim: TestbenchContext):
for cycle in range(test_count):
d = random.randrange(2**data_width)
d = random.randrange(2 ** Shape.cast(shape).width)
a = random.randrange(max_addr)
await m.write[i].call(sim, data={"data": d}, addr=a)
await m.write[i].call(sim, data=to_shape(d), addr=a)
await sim.delay(1e-9 * (i + 2))
data[a] = d
await self.random_wait(sim, writer_rand, min_cycle_cnt=1)
Expand All @@ -257,7 +283,7 @@ async def process(sim: TestbenchContext):
d = await m.read[i].call(sim, addr=a)
await sim.delay(1e-9)
expected_d = data[a]
assert d["data"] == expected_d
assert from_shape(d.data) == expected_d
await self.random_wait(sim, reader_rand, min_cycle_cnt=1)

return process
Expand Down
10 changes: 5 additions & 5 deletions test/lib/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def otransform_dict(_, v: MethodStruct) -> RecordDict:
itransform = itransform_rec
otransform = otransform_rec

m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout))
m.submodules.target = self.target = TestbenchIO(Adapter.create(i=layout, o=layout))

if self.use_methods:
imeth = Method(i=layout, o=layout)
Expand Down Expand Up @@ -111,8 +111,8 @@ class TestMethodFilter(TestCaseWithSimulator):
def initialize(self):
self.iosize = 4
self.layout = data_layout(self.iosize)
self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout))
self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1)))
self.target = TestbenchIO(Adapter.create(i=self.layout, o=self.layout))
self.cmeth = TestbenchIO(Adapter.create(i=self.layout, o=data_layout(1)))

async def source(self, sim: TestbenchContext):
for i in range(2**self.iosize):
Expand Down Expand Up @@ -165,7 +165,7 @@ def elaborate(self, platform):
methods = []

for k in range(self.targets):
tgt = TestbenchIO(Adapter(i=layout, o=layout))
tgt = TestbenchIO(Adapter.create(i=layout, o=layout))
methods.append(tgt.adapter.iface)
self.target.append(tgt)
m.submodules += tgt
Expand Down Expand Up @@ -280,7 +280,7 @@ def elaborate(self, platform):
methods = []

for k in range(self.targets):
tgt = TestbenchIO(Adapter(i=layout, o=layout))
tgt = TestbenchIO(Adapter.create(i=layout, o=layout))
methods.append(tgt.adapter.iface)
self.target.append(tgt)
m.submodules += tgt
Expand Down
Loading

0 comments on commit a83193b

Please sign in to comment.