Skip to content

Commit

Permalink
Refactor def_helper (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Oct 24, 2023
1 parent 3da0f5c commit a628256
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 14 deletions.
4 changes: 2 additions & 2 deletions test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from transactron.core import SignalBundle, Method, TransactionModule
from transactron.lib import AdapterBase, AdapterTrans
from transactron._utils import method_def_helper
from transactron._utils import mock_def_helper
from coreblocks.utils import ValueLike, HasElaborate, HasDebugSignals, auto_debug_signals, LayoutLike, ModuleConnector
from .gtkw_extension import write_vcd_ext

Expand Down Expand Up @@ -363,7 +363,7 @@ def method_handle(
for _ in range(extra_settle_count + 1):
yield Settle()

ret_out = method_def_helper(self, function, **arg)
ret_out = mock_def_helper(self, function, arg)
yield from self.method_return(ret_out or {})
yield

Expand Down
35 changes: 24 additions & 11 deletions transactron/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import sys
from inspect import Parameter, signature
from typing import Optional, TypeAlias, TypeVar
from typing import Any, Concatenate, Optional, TypeAlias, TypeGuard, TypeVar
from collections.abc import Callable, Iterable, Mapping
from amaranth import *
from coreblocks.utils._typing import LayoutLike
Expand All @@ -15,11 +15,13 @@
"Graph",
"GraphCC",
"get_caller_class_name",
"def_helper",
"method_def_helper",
]


T = TypeVar("T")
U = TypeVar("U")


class Scheduler(Elaboratable):
Expand Down Expand Up @@ -122,24 +124,35 @@ def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]:
MethodLayout: TypeAlias = LayoutLike


def method_def_helper(method, func: Callable[..., T], arg=None, /, **kwargs) -> T:
def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]:
parameters = signature(func).parameters
return (
len(parameters) >= 1
and next(iter(parameters)) == name
and parameters[name].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY}
and parameters[name].annotation in {Parameter.empty, tp}
)


def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kwargs) -> T:
parameters = signature(func).parameters
kw_parameters = set(
n for n, p in parameters.items() if p.kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY}
)
if (
len(parameters) == 1
and "arg" in parameters
and parameters["arg"].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY}
and parameters["arg"].annotation in {Parameter.empty, Record}
):
if arg is None:
arg = kwargs
if len(parameters) == 1 and has_first_param(func, "arg", tp):
return func(arg)
elif kw_parameters <= kwargs.keys():
return func(**kwargs)
else:
raise TypeError(f"Invalid method definition/mock for {method}: {func}")
raise TypeError(f"Invalid {description}: {func}")


def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T:
return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg)


def method_def_helper(method, func: Callable[..., T], arg: Record) -> T:
return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields)


def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]:
Expand Down
2 changes: 1 addition & 1 deletion transactron/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,7 @@ def decorator(func: Callable[..., Optional[RecordDict]]):
ret_out = None

with method.body(m, ready=ready, out=out) as arg:
ret_out = method_def_helper(method, func, arg, **arg.fields)
ret_out = method_def_helper(method, func, arg)

if ret_out is not None:
m.d.top_comb += assign(out, ret_out, fields=AssignType.ALL)
Expand Down

0 comments on commit a628256

Please sign in to comment.