Skip to content

Commit

Permalink
Shifting with placeholders
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Nov 30, 2024
1 parent b3fd39b commit 12b7cdb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 18 deletions.
20 changes: 13 additions & 7 deletions transactron/utils/amaranth_ext/functions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from amaranth import *
from amaranth.hdl import ShapeCastable, ValueCastable
from amaranth.utils import bits_for, exact_log2
from amaranth.lib import data
from collections.abc import Iterable, Mapping

from amaranth_types.types import ValueLike
from transactron.utils._typing import SignalBundle

__all__ = [
"mod_incr",
"popcount",
"count_leading_zeros",
"count_trailing_zeros",
"flatten_signals",
]
__all__ = ["mod_incr", "popcount", "count_leading_zeros", "count_trailing_zeros", "flatten_signals", "shape_of"]


def mod_incr(sig: Value, mod: int) -> Value:
Expand Down Expand Up @@ -97,3 +94,12 @@ def flatten_signals(signals: SignalBundle) -> Iterable[Signal]:
yield from flatten_signals(signals[x])
else:
yield signals


def shape_of(value: ValueLike) -> Shape | ShapeCastable:
if isinstance(value, ValueCastable):
shape = value.shape()
assert isinstance(shape, (Shape, ShapeCastable))
return shape
else:
return Value.cast(value).shape()
69 changes: 58 additions & 11 deletions transactron/utils/amaranth_ext/shifter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from amaranth import *
from amaranth.hdl import ValueCastable, ShapeCastable
from amaranth.hdl import ValueCastable
from collections.abc import Sequence
from typing import TypeVar, overload
from typing import Optional, TypeVar, overload
from amaranth_types.types import ValueLike
from .functions import shape_of


__all__ = [
Expand All @@ -14,6 +15,8 @@
"rotate_left",
"generic_shift_vec_right",
"generic_shift_vec_left",
"shift_vec_right",
"shift_vec_left",
"rotate_vec_right",
"rotate_vec_left",
]
Expand All @@ -35,14 +38,18 @@ def generic_shift_left(value1: ValueLike, value2: ValueLike, offset: ValueLike)
return Cat(*reversed(generic_shift_right(Cat(*reversed(value2)), Cat(*reversed(value1)), offset)))


def shift_right(value: ValueLike, offset: ValueLike) -> Value:
def shift_right(value: ValueLike, offset: ValueLike, placeholder: ValueLike = 0) -> Value:
value = Value.cast(value)
return generic_shift_right(value, C(0, len(value)), offset)
placeholder = Value.cast(placeholder)
assert len(placeholder) == 1
return generic_shift_right(value, placeholder.replicate(len(value)), offset)


def shift_left(value: ValueLike, offset: ValueLike) -> Value:
def shift_left(value: ValueLike, offset: ValueLike, placeholder: ValueLike = 0) -> Value:
value = Value.cast(value)
return generic_shift_left(value, C(0, len(value)), offset)
placeholder = Value.cast(placeholder)
assert len(placeholder) == 1
return generic_shift_left(value, placeholder.replicate(len(value)), offset)


def rotate_right(value: ValueLike, offset: ValueLike) -> Value:
Expand All @@ -68,11 +75,7 @@ def generic_shift_vec_right(
def generic_shift_vec_right(
data1: Sequence[ValueLike | ValueCastable], data2: Sequence[ValueLike | ValueCastable], offset: ValueLike
) -> Sequence[Value | ValueCastable]:
if isinstance(data1[0], ValueCastable):
shape = data1[0].shape()
else:
shape = Value.cast(data1[0]).shape()
assert isinstance(shape, (Shape | ShapeCastable))
shape = shape_of(data1[0])

data1_values = [Value.cast(entry) for entry in data1]
data2_values = [Value.cast(entry) for entry in data2]
Expand Down Expand Up @@ -113,6 +116,50 @@ def generic_shift_vec_left(
return list(reversed(generic_shift_vec_right(list(reversed(data2)), list(reversed(data1)), offset)))


@overload
def shift_vec_right(
data: Sequence[_T_ValueCastable], offset: ValueLike, placeholder: Optional[_T_ValueCastable]
) -> Sequence[_T_ValueCastable]: ...


@overload
def shift_vec_right(
data: Sequence[ValueLike], offset: ValueLike, placeholder: Optional[ValueLike]
) -> Sequence[Value]: ...


def shift_vec_right(
data: Sequence[ValueLike | ValueCastable],
offset: ValueLike,
placeholder: Optional[ValueLike | ValueCastable] = None,
) -> Sequence[Value | ValueCastable]:
if placeholder is None:
placeholder = C(0, shape_of(data[0]))
return generic_shift_vec_right(data, [placeholder] * len(data), offset)


@overload
def shift_vec_left(
data: Sequence[_T_ValueCastable], offset: ValueLike, placeholder: Optional[_T_ValueCastable]
) -> Sequence[_T_ValueCastable]: ...


@overload
def shift_vec_left(
data: Sequence[ValueLike], offset: ValueLike, placeholder: Optional[ValueLike]
) -> Sequence[Value]: ...


def shift_vec_left(
data: Sequence[ValueLike | ValueCastable],
offset: ValueLike,
placeholder: Optional[ValueLike | ValueCastable] = None,
) -> Sequence[Value | ValueCastable]:
if placeholder is None:
placeholder = C(0, shape_of(data[0]))
return generic_shift_vec_left(data, [placeholder] * len(data), offset)


@overload
def rotate_vec_right(data: Sequence[_T_ValueCastable], offset: ValueLike) -> Sequence[_T_ValueCastable]: ...

Expand Down

0 comments on commit 12b7cdb

Please sign in to comment.