Skip to content

Commit

Permalink
Refactor def_helper (kuznia-rdzeni/coreblocks#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Oct 24, 2023
1 parent 63080d9 commit 2a11ca9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
35 changes: 24 additions & 11 deletions transactron/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import sys
from inspect import Parameter, signature
from typing import Optional, TypeAlias, TypeVar
from typing import Any, Concatenate, Optional, TypeAlias, TypeGuard, TypeVar
from collections.abc import Callable, Iterable, Mapping
from amaranth import *
from coreblocks.utils._typing import LayoutLike
Expand All @@ -15,11 +15,13 @@
"Graph",
"GraphCC",
"get_caller_class_name",
"def_helper",
"method_def_helper",
]


T = TypeVar("T")
U = TypeVar("U")


class Scheduler(Elaboratable):
Expand Down Expand Up @@ -122,24 +124,35 @@ def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]:
MethodLayout: TypeAlias = LayoutLike


def method_def_helper(method, func: Callable[..., T], arg=None, /, **kwargs) -> T:
def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]:
parameters = signature(func).parameters
return (
len(parameters) >= 1
and next(iter(parameters)) == name
and parameters[name].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY}
and parameters[name].annotation in {Parameter.empty, tp}
)


def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kwargs) -> T:
parameters = signature(func).parameters
kw_parameters = set(
n for n, p in parameters.items() if p.kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY}
)
if (
len(parameters) == 1
and "arg" in parameters
and parameters["arg"].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY}
and parameters["arg"].annotation in {Parameter.empty, Record}
):
if arg is None:
arg = kwargs
if len(parameters) == 1 and has_first_param(func, "arg", tp):
return func(arg)
elif kw_parameters <= kwargs.keys():
return func(**kwargs)
else:
raise TypeError(f"Invalid method definition/mock for {method}: {func}")
raise TypeError(f"Invalid {description}: {func}")


def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T:
return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg)


def method_def_helper(method, func: Callable[..., T], arg: Record) -> T:
return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields)


def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]:
Expand Down
2 changes: 1 addition & 1 deletion transactron/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,7 @@ def decorator(func: Callable[..., Optional[RecordDict]]):
ret_out = None

with method.body(m, ready=ready, out=out) as arg:
ret_out = method_def_helper(method, func, arg, **arg.fields)
ret_out = method_def_helper(method, func, arg)

if ret_out is not None:
m.d.top_comb += assign(out, ret_out, fields=AssignType.ALL)
Expand Down

0 comments on commit 2a11ca9

Please sign in to comment.