Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data dependent method calling #478

Merged
merged 11 commits into from
Nov 14, 2023
88 changes: 88 additions & 0 deletions test/transactions/test_methods.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from amaranth import *
from amaranth.sim import *

Expand Down Expand Up @@ -529,3 +530,90 @@ def process():

with self.run_simulation(circ) as sim:
sim.add_sync_process(process)


class DataDependentConditionalCircuit(Elaboratable):
def __init__(self, n=2, ready_function=lambda arg: arg.data != 3):
self.method = Method(i=data_layout(n))
self.ready_function = ready_function

self.in_t1 = Record(data_layout(n))
self.in_t2 = Record(data_layout(n))
self.ready = Signal()
self.req_t1 = Signal()
self.req_t2 = Signal()

self.out_m = Signal()
self.out_t1 = Signal()
self.out_t2 = Signal()

def elaborate(self, platform):
m = TModule()

@def_method(m, self.method, self.ready, ready_function=self.ready_function)
def _(data):
m.d.comb += self.out_m.eq(1)

with Transaction().body(m, request=self.req_t1):
m.d.comb += self.out_t1.eq(1)
self.method(m, self.in_t1)

with Transaction().body(m, request=self.req_t2):
m.d.comb += self.out_t2.eq(1)
self.method(m, self.in_t2)

return m


class TestDataDependentConditionalMethod(TestCaseWithSimulator):
def setUp(self):
self.test_number = 200
self.bad_number = 3
self.n = 2

def base_random(self, f):
random.seed(14)
self.circ = DataDependentConditionalCircuit(n=self.n, ready_function=f)

def process():
for _ in range(self.test_number):
in1 = random.randrange(0, 2**self.n)
in2 = random.randrange(0, 2**self.n)
m_ready = random.randrange(2)
req_t1 = random.randrange(2)
req_t2 = random.randrange(2)

yield self.circ.in_t1.eq(in1)
yield self.circ.in_t2.eq(in2)
yield self.circ.req_t1.eq(req_t1)
yield self.circ.req_t2.eq(req_t2)
yield self.circ.ready.eq(m_ready)
yield Settle()
yield Delay(1e-8)

out_m = yield self.circ.out_m
out_t1 = yield self.circ.out_t1
out_t2 = yield self.circ.out_t2

if not m_ready or (not req_t1 or in1 == self.bad_number) and (not req_t2 or in2 == self.bad_number):
self.assertEqual(out_m, 0)
self.assertEqual(out_t1, 0)
self.assertEqual(out_t2, 0)
continue
# Here method global ready signal is high and we requested one of the transactions
# we also know that one of the transactions request correct input data

self.assertEqual(out_m, 1)
self.assertEqual(out_t1 ^ out_t2, 1)
# inX == self.bad_number implies out_tX==0
self.assertTrue(in1 != self.bad_number or not out_t1)
self.assertTrue(in2 != self.bad_number or not out_t2)

with self.run_simulation(self.circ, 100) as sim:
sim.add_process(process)

def test_random_arg(self):
self.base_random(lambda arg: arg.data != self.bad_number)

def test_random_kwarg(self):
self.base_random(lambda data: data != self.bad_number)
57 changes: 44 additions & 13 deletions transactron/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,17 @@ class MethodMap:
def __init__(self, transactions: Iterable["Transaction"]):
self.methods_by_transaction = dict[Transaction, list[Method]]()
self.transactions_by_method = defaultdict[Method, list[Transaction]](list)
self.readiness_by_method_and_transaction = dict[tuple[Transaction, Method], ValueLike]()

def rec(transaction: Transaction, source: TransactionBase):
for method in source.method_uses.keys():
for method, (arg_rec, _) in source.method_uses.items():
if not method.defined:
raise RuntimeError(f"Trying to use method '{method.name}' which is not defined yet")
if method in self.methods_by_transaction[transaction]:
raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction")
self.methods_by_transaction[transaction].append(method)
self.transactions_by_method[method].append(transaction)
self.readiness_by_method_and_transaction[(transaction, method)] = method._ready_function(arg_rec)
rec(transaction, method)

for transaction in transactions:
Expand Down Expand Up @@ -139,7 +141,10 @@ def eager_deterministic_cc_scheduler(
ccl = list(cc)
ccl.sort(key=lambda transaction: porder[transaction])
for k, transaction in enumerate(ccl):
ready = [method.ready for method in method_map.methods_by_transaction[transaction]]
ready = [
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved
method_map.readiness_by_method_and_transaction[(transaction, method)]
for method in method_map.methods_by_transaction[transaction]
]
runnable = Cat(ready).all()
conflicts = [ccl[j].grant for j in range(k) if ccl[j] in gr[transaction]]
noconflict = ~Cat(conflicts).any()
Expand Down Expand Up @@ -175,11 +180,11 @@ def trivial_roundrobin_cc_scheduler(
sched = Scheduler(len(cc))
m.submodules.scheduler = sched
for k, transaction in enumerate(cc):
methods = method_map.methods_by_transaction[transaction]
ready = Signal(len(methods))
for n, method in enumerate(methods):
m.d.comb += ready[n].eq(method.ready)
runnable = ready.all()
ready = [
method_map.readiness_by_method_and_transaction[(transaction, method)]
for method in method_map.methods_by_transaction[transaction]
]
runnable = Cat(ready).all()
m.d.comb += sched.requests[k].eq(transaction.request & runnable)
m.d.comb += transaction.grant.eq(sched.grant[k] & sched.valid)
return m
Expand Down Expand Up @@ -689,13 +694,13 @@ class TransactionBase(Owned, Protocol):
def_order: int
defined: bool = False
name: str
method_uses: dict["Method", Tuple[ValueLike, ValueLike]]
method_uses: dict["Method", Tuple[Record, ValueLike]]
relations: list[RelationBase]
simultaneous_list: list[TransactionOrMethod]
independent_list: list[TransactionOrMethod]

def __init__(self):
self.method_uses: dict["Method", Tuple[ValueLike, ValueLike]] = dict()
self.method_uses: dict["Method", Tuple[Record, ValueLike]] = dict()
self.relations: list[RelationBase] = []
self.simultaneous_list: list[TransactionOrMethod] = []
self.independent_list: list[TransactionOrMethod] = []
Expand Down Expand Up @@ -731,7 +736,7 @@ def schedule_before(self, end: TransactionOrMethod) -> None:
"""
self.relations.append(RelationBase(end=end, priority=Priority.LEFT, conflict=False))

def use_method(self, method: "Method", arg: ValueLike, enable: ValueLike):
def use_method(self, method: "Method", arg: Record, enable: ValueLike):
if method in self.method_uses:
raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction '{self.name}'")
self.method_uses[method] = (arg, enable)
Expand Down Expand Up @@ -994,6 +999,7 @@ def __init__(
self.data_out = Record(o)
self.nonexclusive = nonexclusive
self.single_caller = single_caller
self.ready_function: Optional[Callable[[Record], ValueLike]] = None
if nonexclusive:
assert len(self.data_in) == 0

Expand Down Expand Up @@ -1037,7 +1043,14 @@ def _(arg):
return method(m, arg)

@contextmanager
def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0)) -> Iterator[Record]:
def body(
self,
m: TModule,
*,
ready: ValueLike = C(1),
out: ValueLike = C(0, 0),
ready_function: Optional[Callable[[Record], ValueLike]] = None,
) -> Iterator[Record]:
"""Define method body

The `body` context manager can be used to define the actions
Expand All @@ -1060,6 +1073,11 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0))
Data generated by the `Method`, which will be passed to
the caller (a `Transaction` or another `Method`). Assigned
combinationally to the `data_out` attribute.
ready_function: Optional[Callable[[Record], ValueLike]]
tilk marked this conversation as resolved.
Show resolved Hide resolved
Function to instantiate a combinational circuit for each
method caller. It should take input arguments and return
if the method can be called with those arguments. By default
there is no function, so all arguments are accepted.
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand All @@ -1081,6 +1099,7 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0))
if self.defined:
raise RuntimeError(f"Method '{self.name}' already defined")
self.def_order = next(TransactionBase.def_counter)
self.ready_function = ready_function

try:
m.d.av_comb += self.ready.eq(ready)
Expand All @@ -1091,6 +1110,11 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0))
finally:
self.defined = True

def _ready_function(self, arg_rec: Record) -> ValueLike:
if self.ready_function is not None:
return self.ready & method_def_helper(self, self.ready_function, arg_rec, **arg_rec.fields)
return self.ready

def __call__(
self, m: TModule, arg: Optional[RecordDict] = None, enable: ValueLike = C(1), /, **kwargs: RecordDict
) -> Record:
Expand Down Expand Up @@ -1162,7 +1186,9 @@ def debug_signals(self) -> SignalBundle:
return [self.ready, self.run, self.data_in, self.data_out]


def def_method(m: TModule, method: Method, ready: ValueLike = C(1)):
def def_method(
m: TModule, method: Method, ready: ValueLike = C(1), ready_function: Optional[Callable[[Record], ValueLike]] = None
lekcyjna123 marked this conversation as resolved.
Show resolved Hide resolved
tilk marked this conversation as resolved.
Show resolved Hide resolved
):
"""Define a method.

This decorator allows to define transactional methods in an
Expand All @@ -1187,6 +1213,11 @@ def def_method(m: TModule, method: Method, ready: ValueLike = C(1)):
Signal to indicate if the method is ready to be run. By
default it is `Const(1)`, so the method is always ready.
Assigned combinationally to the `ready` attribute.
ready_function: Optional[Callable[[Record], ValueLike]]
Function to instantiate a combinational circuit for each
method caller. It should take input arguments and return
if the method can be called with those arguments. By default
there is no function, so all arguments are accepted.

Examples
--------
Expand Down Expand Up @@ -1222,7 +1253,7 @@ def decorator(func: Callable[..., Optional[RecordDict]]):
out = Record.like(method.data_out)
ret_out = None

with method.body(m, ready=ready, out=out) as arg:
with method.body(m, ready=ready, out=out, ready_function=ready_function) as arg:
ret_out = method_def_helper(method, func, arg, **arg.fields)

if ret_out is not None:
Expand Down