diff --git a/transactron/core.py b/transactron/core.py index a4b3f3b..ba0f115 100644 --- a/transactron/core.py +++ b/transactron/core.py @@ -75,15 +75,17 @@ class MethodMap: def __init__(self, transactions: Iterable["Transaction"]): self.methods_by_transaction = dict[Transaction, list[Method]]() self.transactions_by_method = defaultdict[Method, list[Transaction]](list) + self.readiness_by_method_and_transaction = dict[tuple[Transaction, Method], ValueLike]() def rec(transaction: Transaction, source: TransactionBase): - for method in source.method_uses.keys(): + for method, (arg_rec, _) in source.method_uses.items(): if not method.defined: raise RuntimeError(f"Trying to use method '{method.name}' which is not defined yet") if method in self.methods_by_transaction[transaction]: raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction") self.methods_by_transaction[transaction].append(method) self.transactions_by_method[method].append(transaction) + self.readiness_by_method_and_transaction[(transaction, method)] = method._validate_arguments(arg_rec) rec(transaction, method) for transaction in transactions: @@ -139,7 +141,10 @@ def eager_deterministic_cc_scheduler( ccl = list(cc) ccl.sort(key=lambda transaction: porder[transaction]) for k, transaction in enumerate(ccl): - ready = [method.ready for method in method_map.methods_by_transaction[transaction]] + ready = [ + method_map.readiness_by_method_and_transaction[(transaction, method)] + for method in method_map.methods_by_transaction[transaction] + ] runnable = Cat(ready).all() conflicts = [ccl[j].grant for j in range(k) if ccl[j] in gr[transaction]] noconflict = ~Cat(conflicts).any() @@ -175,11 +180,11 @@ def trivial_roundrobin_cc_scheduler( sched = Scheduler(len(cc)) m.submodules.scheduler = sched for k, transaction in enumerate(cc): - methods = method_map.methods_by_transaction[transaction] - ready = Signal(len(methods)) - for n, method in enumerate(methods): - m.d.comb += ready[n].eq(method.ready) - runnable = ready.all() + ready = [ + method_map.readiness_by_method_and_transaction[(transaction, method)] + for method in method_map.methods_by_transaction[transaction] + ] + runnable = Cat(ready).all() m.d.comb += sched.requests[k].eq(transaction.request & runnable) m.d.comb += transaction.grant.eq(sched.grant[k] & sched.valid) return m @@ -689,13 +694,13 @@ class TransactionBase(Owned, Protocol): def_order: int defined: bool = False name: str - method_uses: dict["Method", Tuple[ValueLike, ValueLike]] + method_uses: dict["Method", Tuple[Record, ValueLike]] relations: list[RelationBase] simultaneous_list: list[TransactionOrMethod] independent_list: list[TransactionOrMethod] def __init__(self): - self.method_uses: dict["Method", Tuple[ValueLike, ValueLike]] = dict() + self.method_uses: dict["Method", Tuple[Record, ValueLike]] = dict() self.relations: list[RelationBase] = [] self.simultaneous_list: list[TransactionOrMethod] = [] self.independent_list: list[TransactionOrMethod] = [] @@ -735,7 +740,7 @@ def schedule_before(self, end: TransactionOrMethod) -> None: RelationBase(end=end, priority=Priority.LEFT, conflict=False, silence_warning=self.owner != end.owner) ) - def use_method(self, method: "Method", arg: ValueLike, enable: ValueLike): + def use_method(self, method: "Method", arg: Record, enable: ValueLike): if method in self.method_uses: raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction '{self.name}'") self.method_uses[method] = (arg, enable) @@ -998,6 +1003,7 @@ def __init__( self.data_out = Record(o) self.nonexclusive = nonexclusive self.single_caller = single_caller + self.validate_arguments: Optional[Callable[..., ValueLike]] = None if nonexclusive: assert len(self.data_in) == 0 @@ -1041,7 +1047,14 @@ def _(arg): return method(m, arg) @contextmanager - def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0)) -> Iterator[Record]: + def body( + self, + m: TModule, + *, + ready: ValueLike = C(1), + out: ValueLike = C(0, 0), + validate_arguments: Optional[Callable[..., ValueLike]] = None, + ) -> Iterator[Record]: """Define method body The `body` context manager can be used to define the actions @@ -1064,6 +1077,12 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0)) Data generated by the `Method`, which will be passed to the caller (a `Transaction` or another `Method`). Assigned combinationally to the `data_out` attribute. + validate_arguments: Optional[Callable[..., ValueLike]] + Function that takes input arguments used to call the method + and checks whether the method can be called with those arguments. + It instantiates a combinational circuit for each + method caller. By default, there is no function, so all arguments + are accepted. Returns ------- @@ -1085,6 +1104,7 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0)) if self.defined: raise RuntimeError(f"Method '{self.name}' already defined") self.def_order = next(TransactionBase.def_counter) + self.validate_arguments = validate_arguments try: m.d.av_comb += self.ready.eq(ready) @@ -1095,6 +1115,11 @@ def body(self, m: TModule, *, ready: ValueLike = C(1), out: ValueLike = C(0, 0)) finally: self.defined = True + def _validate_arguments(self, arg_rec: Record) -> ValueLike: + if self.validate_arguments is not None: + return self.ready & method_def_helper(self, self.validate_arguments, arg_rec) + return self.ready + def __call__( self, m: TModule, arg: Optional[RecordDict] = None, enable: ValueLike = C(1), /, **kwargs: RecordDict ) -> Record: @@ -1166,7 +1191,12 @@ def debug_signals(self) -> SignalBundle: return [self.ready, self.run, self.data_in, self.data_out] -def def_method(m: TModule, method: Method, ready: ValueLike = C(1)): +def def_method( + m: TModule, + method: Method, + ready: ValueLike = C(1), + validate_arguments: Optional[Callable[..., ValueLike]] = None, +): """Define a method. This decorator allows to define transactional methods in an @@ -1191,6 +1221,12 @@ def def_method(m: TModule, method: Method, ready: ValueLike = C(1)): Signal to indicate if the method is ready to be run. By default it is `Const(1)`, so the method is always ready. Assigned combinationally to the `ready` attribute. + validate_arguments: Optional[Callable[..., ValueLike]] + Function that takes input arguments used to call the method + and checks whether the method can be called with those arguments. + It instantiates a combinational circuit for each + method caller. By default, there is no function, so all arguments + are accepted. Examples -------- @@ -1226,7 +1262,7 @@ def decorator(func: Callable[..., Optional[RecordDict]]): out = Record.like(method.data_out) ret_out = None - with method.body(m, ready=ready, out=out) as arg: + with method.body(m, ready=ready, out=out, validate_arguments=validate_arguments) as arg: ret_out = method_def_helper(method, func, arg) if ret_out is not None: