Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data dependent method calling #478

Merged
merged 11 commits into from
Nov 14, 2023
82 changes: 82 additions & 0 deletions test/transactions/test_methods.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from amaranth import *
from amaranth.sim import *

Expand Down Expand Up @@ -529,3 +530,84 @@ def process():

with self.run_simulation(circ) as sim:
sim.add_sync_process(process)


class DataDependentConditionalCircuit(Elaboratable):
def __init__(self, n=2, bad_number=3):
self.bad_number = bad_number
self.method = Method(i=data_layout(n))

self.in_t1 = Record(data_layout(n))
self.in_t2 = Record(data_layout(n))
self.ready = Signal()
self.req_t1 = Signal()
self.req_t2 = Signal()

self.out_m = Signal()
self.out_t1 = Signal()
self.out_t2 = Signal()

def elaborate(self, platform):
m = TModule()

@def_method(m, self.method, self.ready, ready_function=lambda rec: rec.data != self.bad_number)
def _(data):
m.d.comb += self.out_m.eq(1)

with Transaction().body(m, request=self.req_t1):
m.d.comb += self.out_t1.eq(1)
self.method(m, self.in_t1)

with Transaction().body(m, request=self.req_t2):
m.d.comb += self.out_t2.eq(1)
self.method(m, self.in_t2)

return m


class TestDataDependentConditionalMethod(TestCaseWithSimulator):
def setUp(self):
self.test_number = 200
self.bad_number = 3
self.n = 2
self.circ = DataDependentConditionalCircuit(self.n, self.bad_number)

def test_random(self):
random.seed(14)

def process():
for _ in range(self.test_number):
in1 = random.randrange(0, 2**self.n)
in2 = random.randrange(0, 2**self.n)
m_ready = random.randrange(2)
req_t1 = random.randrange(2)
req_t2 = random.randrange(2)

yield self.circ.in_t1.eq(in1)
yield self.circ.in_t2.eq(in2)
yield self.circ.req_t1.eq(req_t1)
yield self.circ.req_t2.eq(req_t2)
yield self.circ.ready.eq(m_ready)
yield Settle()
yield Delay(1e-8)

out_m = yield self.circ.out_m
out_t1 = yield self.circ.out_t1
out_t2 = yield self.circ.out_t2

if not m_ready or (not req_t1 or in1 == self.bad_number) and (not req_t2 or in2 == self.bad_number):
self.assertEqual(out_m, 0)
self.assertEqual(out_t1, 0)
self.assertEqual(out_t2, 0)
continue
# Here method global ready signal is high and we requested one of the transactions
# we also know that one of the transactions request correct input data

self.assertEqual(out_m, 1)
self.assertEqual(out_t1 ^ out_t2, 1)
# inX == self.bad_number implies out_tX==0
self.assertTrue(in1 != self.bad_number or not out_t1)
self.assertTrue(in2 != self.bad_number or not out_t2)

with self.run_simulation(self.circ, 100) as sim:
sim.add_process(process)
35 changes: 28 additions & 7 deletions transactron/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,17 @@ class MethodMap:
def __init__(self, transactions: Iterable["Transaction"]):
self.methods_by_transaction = dict[Transaction, list[Method]]()
self.transactions_by_method = defaultdict[Method, list[Transaction]](list)
self.arguments_by_method_by_transaction = defaultdict[Transaction, dict[Method, Record]](dict)
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved

def rec(transaction: Transaction, source: TransactionBase):
for method in source.method_uses.keys():
for method, (arg_rec, _) in source.method_uses.items():
if not method.defined:
raise RuntimeError(f"Trying to use method '{method.name}' which is not defined yet")
if method in self.methods_by_transaction[transaction]:
raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction")
self.methods_by_transaction[transaction].append(method)
self.transactions_by_method[method].append(transaction)
self.arguments_by_method_by_transaction[transaction][method] = arg_rec
rec(transaction, method)

for transaction in transactions:
Expand Down Expand Up @@ -127,7 +129,10 @@ def eager_deterministic_cc_scheduler(
ccl = list(cc)
ccl.sort(key=lambda transaction: porder[transaction])
for k, transaction in enumerate(ccl):
ready = [method.ready for method in method_map.methods_by_transaction[transaction]]
ready = [
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved
method.ready_function(method_map.arguments_by_method_by_transaction[transaction][method])
for method in method_map.methods_by_transaction[transaction]
]
runnable = Cat(ready).all()
conflicts = [ccl[j].grant for j in range(k) if ccl[j] in gr[transaction]]
noconflict = ~Cat(conflicts).any()
Expand Down Expand Up @@ -678,7 +683,7 @@ class TransactionBase(Owned):
name: str

def __init__(self):
self.method_uses: dict[Method, Tuple[ValueLike, ValueLike]] = dict()
self.method_uses: dict[Method, Tuple[Record, ValueLike]] = dict()
self.relations: list[RelationBase] = []
self.simultaneous_list: list[TransactionOrMethod] = []
self.independent_list: list[TransactionOrMethod] = []
Expand Down Expand Up @@ -714,7 +719,7 @@ def schedule_before(self, end: TransactionOrMethod) -> None:
"""
self.relations.append(RelationBase(end=end, priority=Priority.LEFT, conflict=False))

def use_method(self, method: "Method", arg: ValueLike, enable: ValueLike):
def use_method(self, method: "Method", arg: Record, enable: ValueLike):
if method in self.method_uses:
raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction '{self.name}'")
self.method_uses[method] = (arg, enable)
Expand Down Expand Up @@ -979,6 +984,7 @@ def __init__(
self.data_out = Record(o)
self.nonexclusive = nonexclusive
self.single_caller = single_caller
self.user_ready_function: Optional[Callable[[Record], ValueLike]] = None
if nonexclusive:
assert len(self.data_in) == 0

Expand Down Expand Up @@ -1022,7 +1028,14 @@ def _(arg):
return method(m, arg)

@contextmanager
def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0)) -> Iterator[Record]:
def body(
self,
m: TModule,
*,
ready: ValueLike = C(1),
out: ValueLike = C(0, 0),
user_ready_function: Optional[Callable[[Record], ValueLike]] = None,
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved
) -> Iterator[Record]:
"""Define method body

The `body` context manager can be used to define the actions
Expand Down Expand Up @@ -1066,6 +1079,7 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0))
if self.defined:
raise RuntimeError(f"Method '{self.name}' already defined")
self.def_order = next(TransactionBase.def_counter)
self.user_ready_function = user_ready_function

try:
m.d.av_comb += self.ready.eq(ready)
Expand All @@ -1076,6 +1090,11 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0))
finally:
self.defined = True

def ready_function(self, arg_rec: Record) -> ValueLike:
if self.user_ready_function is not None:
return self.ready & self.user_ready_function(arg_rec)
return self.ready
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved

def __call__(
self, m: TModule, arg: Optional[RecordDict] = None, enable: ValueLike = C(1), /, **kwargs: RecordDict
) -> Record:
Expand Down Expand Up @@ -1147,7 +1166,9 @@ def debug_signals(self) -> SignalBundle:
return [self.ready, self.run, self.data_in, self.data_out]


def def_method(m: TModule, method: Method, ready: ValueLike = C(1)):
def def_method(
m: TModule, method: Method, ready: ValueLike = C(1), ready_function: Optional[Callable[[Record], ValueLike]] = None
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved
tilk marked this conversation as resolved.
Show resolved Hide resolved
):
"""Define a method.

This decorator allows to define transactional methods in an
Expand Down Expand Up @@ -1207,7 +1228,7 @@ def decorator(func: Callable[..., Optional[RecordDict]]):
out = Record.like(method.data_out)
ret_out = None

with method.body(m, ready=ready, out=out) as arg:
with method.body(m, ready=ready, out=out, user_ready_function=ready_function) as arg:
ret_out = method_def_helper(method, func, arg, **arg.fields)

if ret_out is not None:
Expand Down