diff --git a/transactron/utils/amaranth_ext/functions.py b/transactron/utils/amaranth_ext/functions.py index d09c7b5..8a1f03d 100644 --- a/transactron/utils/amaranth_ext/functions.py +++ b/transactron/utils/amaranth_ext/functions.py @@ -1,16 +1,13 @@ from amaranth import * +from amaranth.hdl import ShapeCastable, ValueCastable from amaranth.utils import bits_for, exact_log2 from amaranth.lib import data from collections.abc import Iterable, Mapping + +from amaranth_types.types import ValueLike from transactron.utils._typing import SignalBundle -__all__ = [ - "mod_incr", - "popcount", - "count_leading_zeros", - "count_trailing_zeros", - "flatten_signals", -] +__all__ = ["mod_incr", "popcount", "count_leading_zeros", "count_trailing_zeros", "flatten_signals", "shape_of"] def mod_incr(sig: Value, mod: int) -> Value: @@ -97,3 +94,12 @@ def flatten_signals(signals: SignalBundle) -> Iterable[Signal]: yield from flatten_signals(signals[x]) else: yield signals + + +def shape_of(value: ValueLike) -> Shape | ShapeCastable: + if isinstance(value, ValueCastable): + shape = value.shape() + assert isinstance(shape, (Shape, ShapeCastable)) + return shape + else: + return Value.cast(value).shape() diff --git a/transactron/utils/amaranth_ext/shifter.py b/transactron/utils/amaranth_ext/shifter.py index 15afa68..3fd16b1 100644 --- a/transactron/utils/amaranth_ext/shifter.py +++ b/transactron/utils/amaranth_ext/shifter.py @@ -1,8 +1,9 @@ from amaranth import * -from amaranth.hdl import ValueCastable, ShapeCastable +from amaranth.hdl import ValueCastable from collections.abc import Sequence -from typing import TypeVar, overload +from typing import Optional, TypeVar, overload from amaranth_types.types import ValueLike +from .functions import shape_of __all__ = [ @@ -14,6 +15,8 @@ "rotate_left", "generic_shift_vec_right", "generic_shift_vec_left", + "shift_vec_right", + "shift_vec_left", "rotate_vec_right", "rotate_vec_left", ] @@ -35,14 +38,18 @@ def generic_shift_left(value1: ValueLike, value2: ValueLike, offset: ValueLike) return Cat(*reversed(generic_shift_right(Cat(*reversed(value2)), Cat(*reversed(value1)), offset))) -def shift_right(value: ValueLike, offset: ValueLike) -> Value: +def shift_right(value: ValueLike, offset: ValueLike, placeholder: ValueLike = 0) -> Value: value = Value.cast(value) - return generic_shift_right(value, C(0, len(value)), offset) + placeholder = Value.cast(placeholder) + assert len(placeholder) == 1 + return generic_shift_right(value, placeholder.replicate(len(value)), offset) -def shift_left(value: ValueLike, offset: ValueLike) -> Value: +def shift_left(value: ValueLike, offset: ValueLike, placeholder: ValueLike = 0) -> Value: value = Value.cast(value) - return generic_shift_left(value, C(0, len(value)), offset) + placeholder = Value.cast(placeholder) + assert len(placeholder) == 1 + return generic_shift_left(value, placeholder.replicate(len(value)), offset) def rotate_right(value: ValueLike, offset: ValueLike) -> Value: @@ -68,11 +75,7 @@ def generic_shift_vec_right( def generic_shift_vec_right( data1: Sequence[ValueLike | ValueCastable], data2: Sequence[ValueLike | ValueCastable], offset: ValueLike ) -> Sequence[Value | ValueCastable]: - if isinstance(data1[0], ValueCastable): - shape = data1[0].shape() - else: - shape = Value.cast(data1[0]).shape() - assert isinstance(shape, (Shape | ShapeCastable)) + shape = shape_of(data1[0]) data1_values = [Value.cast(entry) for entry in data1] data2_values = [Value.cast(entry) for entry in data2] @@ -113,6 +116,50 @@ def generic_shift_vec_left( return list(reversed(generic_shift_vec_right(list(reversed(data2)), list(reversed(data1)), offset))) +@overload +def shift_vec_right( + data: Sequence[_T_ValueCastable], offset: ValueLike, placeholder: Optional[_T_ValueCastable] +) -> Sequence[_T_ValueCastable]: ... + + +@overload +def shift_vec_right( + data: Sequence[ValueLike], offset: ValueLike, placeholder: Optional[ValueLike] +) -> Sequence[Value]: ... + + +def shift_vec_right( + data: Sequence[ValueLike | ValueCastable], + offset: ValueLike, + placeholder: Optional[ValueLike | ValueCastable] = None, +) -> Sequence[Value | ValueCastable]: + if placeholder is None: + placeholder = C(0, shape_of(data[0])) + return generic_shift_vec_right(data, [placeholder] * len(data), offset) + + +@overload +def shift_vec_left( + data: Sequence[_T_ValueCastable], offset: ValueLike, placeholder: Optional[_T_ValueCastable] +) -> Sequence[_T_ValueCastable]: ... + + +@overload +def shift_vec_left( + data: Sequence[ValueLike], offset: ValueLike, placeholder: Optional[ValueLike] +) -> Sequence[Value]: ... + + +def shift_vec_left( + data: Sequence[ValueLike | ValueCastable], + offset: ValueLike, + placeholder: Optional[ValueLike | ValueCastable] = None, +) -> Sequence[Value | ValueCastable]: + if placeholder is None: + placeholder = C(0, shape_of(data[0])) + return generic_shift_vec_left(data, [placeholder] * len(data), offset) + + @overload def rotate_vec_right(data: Sequence[_T_ValueCastable], offset: ValueLike) -> Sequence[_T_ValueCastable]: ...