diff --git a/transactron/_utils.py b/transactron/_utils.py index f86b0d6..138c722 100644 --- a/transactron/_utils.py +++ b/transactron/_utils.py @@ -1,7 +1,7 @@ import itertools import sys from inspect import Parameter, signature -from typing import Optional, TypeAlias, TypeVar +from typing import Any, Concatenate, Optional, TypeAlias, TypeGuard, TypeVar from collections.abc import Callable, Iterable, Mapping from amaranth import * from coreblocks.utils._typing import LayoutLike @@ -15,11 +15,13 @@ "Graph", "GraphCC", "get_caller_class_name", + "def_helper", "method_def_helper", ] T = TypeVar("T") +U = TypeVar("U") class Scheduler(Elaboratable): @@ -122,24 +124,35 @@ def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]: MethodLayout: TypeAlias = LayoutLike -def method_def_helper(method, func: Callable[..., T], arg=None, /, **kwargs) -> T: +def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]: + parameters = signature(func).parameters + return ( + len(parameters) >= 1 + and next(iter(parameters)) == name + and parameters[name].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY} + and parameters[name].annotation in {Parameter.empty, tp} + ) + + +def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kwargs) -> T: parameters = signature(func).parameters kw_parameters = set( n for n, p in parameters.items() if p.kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY} ) - if ( - len(parameters) == 1 - and "arg" in parameters - and parameters["arg"].kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY} - and parameters["arg"].annotation in {Parameter.empty, Record} - ): - if arg is None: - arg = kwargs + if len(parameters) == 1 and has_first_param(func, "arg", tp): return func(arg) elif kw_parameters <= kwargs.keys(): return func(**kwargs) else: - raise TypeError(f"Invalid method definition/mock for {method}: {func}") + raise TypeError(f"Invalid {description}: {func}") + + +def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: + return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) + + +def method_def_helper(method, func: Callable[..., T], arg: Record) -> T: + return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields) def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]: diff --git a/transactron/core.py b/transactron/core.py index 745b98b..035b3b5 100644 --- a/transactron/core.py +++ b/transactron/core.py @@ -1223,7 +1223,7 @@ def decorator(func: Callable[..., Optional[RecordDict]]): ret_out = None with method.body(m, ready=ready, out=out) as arg: - ret_out = method_def_helper(method, func, arg, **arg.fields) + ret_out = method_def_helper(method, func, arg) if ret_out is not None: m.d.top_comb += assign(out, ret_out, fields=AssignType.ALL)