Skip to content

Commit

Permalink
Ring priority encoder (kuznia-rdzeni/coreblocks#703)
Browse files Browse the repository at this point in the history
  • Loading branch information
lekcyjna123 authored May 27, 2024
1 parent b509086 commit 1ab2585
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 35 deletions.
113 changes: 78 additions & 35 deletions test/utils/test_amaranth_ext.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,75 @@
from transactron.testing import *
import random
from transactron.utils.amaranth_ext import MultiPriorityEncoder


class TestMultiPriorityEncoder(TestCaseWithSimulator):
def get_expected(self, input):
places = []
for i in range(self.input_width):
if input % 2:
places.append(i)
input //= 2
places += [None] * self.output_count
return places

def process(self):
for _ in range(self.test_number):
input = random.randrange(2**self.input_width)
yield self.circ.input.eq(input)
yield Settle()
expected_output = self.get_expected(input)
for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids):
if ex is None:
assert (yield valid) == 0
else:
assert (yield valid) == 1
assert (yield real) == ex
yield Delay(1e-7)
from transactron.utils.amaranth_ext import MultiPriorityEncoder, RingMultiPriorityEncoder


def get_expected_multi(input_width, output_count, input, *args):
places = []
for i in range(input_width):
if input % 2:
places.append(i)
input //= 2
places += [None] * output_count
return places


def get_expected_ring(input_width, output_count, input, first, last):
places = []
input = (input << input_width) + input
if last < first:
last += input_width
for i in range(2 * input_width):
if i >= first and i < last and input % 2:
places.append(i % input_width)
input //= 2
places += [None] * output_count
return places


@pytest.mark.parametrize(
"test_class, verif_f",
[(MultiPriorityEncoder, get_expected_multi), (RingMultiPriorityEncoder, get_expected_ring)],
ids=["MultiPriorityEncoder", "RingMultiPriorityEncoder"],
)
class TestPriorityEncoder(TestCaseWithSimulator):
def process(self, get_expected):
def f():
for _ in range(self.test_number):
input = random.randrange(2**self.input_width)
first = random.randrange(self.input_width)
last = random.randrange(self.input_width)
yield self.circ.input.eq(input)
try:
yield self.circ.first.eq(first)
yield self.circ.last.eq(last)
except AttributeError:
pass
yield Settle()
expected_output = get_expected(self.input_width, self.output_count, input, first, last)
for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids):
if ex is None:
assert (yield valid) == 0
else:
assert (yield valid) == 1
assert (yield real) == ex
yield Delay(1e-7)

return f

@pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24])
@pytest.mark.parametrize("output_count", [1, 3, 4])
def test_random(self, input_width, output_count):
def test_random(self, test_class, verif_f, input_width, output_count):
random.seed(input_width + output_count)
self.test_number = 50
self.input_width = input_width
self.output_count = output_count
self.circ = MultiPriorityEncoder(self.input_width, self.output_count)
self.circ = test_class(self.input_width, self.output_count)

with self.run_simulation(self.circ) as sim:
sim.add_process(self.process)
sim.add_process(self.process(verif_f))

@pytest.mark.parametrize("name", ["prio_encoder", None])
def test_static_create_simple(self, name):
def test_static_create_simple(self, test_class, verif_f, name):
random.seed(14)
self.test_number = 50
self.input_width = 7
Expand All @@ -49,13 +78,20 @@ def test_static_create_simple(self, name):
class DUT(Elaboratable):
def __init__(self, input_width, output_count, name):
self.input = Signal(input_width)
self.first = Signal(range(input_width))
self.last = Signal(range(input_width))
self.output_count = output_count
self.input_width = input_width
self.name = name

def elaborate(self, platform):
m = Module()
out, val = MultiPriorityEncoder.create_simple(m, self.input_width, self.input, name=self.name)
if test_class == MultiPriorityEncoder:
out, val = test_class.create_simple(m, self.input_width, self.input, name=self.name)
else:
out, val = test_class.create_simple(
m, self.input_width, self.input, self.first, self.last, name=self.name
)
# Save as a list to use common interface in testing
self.outputs = [out]
self.valids = [val]
Expand All @@ -64,10 +100,10 @@ def elaborate(self, platform):
self.circ = DUT(self.input_width, self.output_count, name)

with self.run_simulation(self.circ) as sim:
sim.add_process(self.process)
sim.add_process(self.process(verif_f))

@pytest.mark.parametrize("name", ["prio_encoder", None])
def test_static_create(self, name):
def test_static_create(self, test_class, verif_f, name):
random.seed(14)
self.test_number = 50
self.input_width = 7
Expand All @@ -76,17 +112,24 @@ def test_static_create(self, name):
class DUT(Elaboratable):
def __init__(self, input_width, output_count, name):
self.input = Signal(input_width)
self.first = Signal(range(input_width))
self.last = Signal(range(input_width))
self.output_count = output_count
self.input_width = input_width
self.name = name

def elaborate(self, platform):
m = Module()
out = MultiPriorityEncoder.create(m, self.input_width, self.input, self.output_count, name=self.name)
if test_class == MultiPriorityEncoder:
out = test_class.create(m, self.input_width, self.input, self.output_count, name=self.name)
else:
out = test_class.create(
m, self.input_width, self.input, self.first, self.last, self.output_count, name=self.name
)
self.outputs, self.valids = list(zip(*out))
return m

self.circ = DUT(self.input_width, self.output_count, name)

with self.run_simulation(self.circ) as sim:
sim.add_process(self.process)
sim.add_process(self.process(verif_f))
153 changes: 153 additions & 0 deletions transactron/utils/amaranth_ext/elaboratables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"Scheduler",
"RoundRobin",
"MultiPriorityEncoder",
"RingMultiPriorityEncoder",
]


Expand Down Expand Up @@ -377,3 +378,155 @@ def elaborate(self, platform):
m.d.comb += self.valids[k].eq(level_valids[k])

return m


class RingMultiPriorityEncoder(Elaboratable):
"""Priority encoder with one or more outputs and flexible start
This is an extension of the `MultiPriorityEncoder` that supports
flexible start and end indexes. In the standard `MultiPriorityEncoder`
the first bit is always at position 0 and the last is the last bit of
the input signal. In this extended implementation, both can be
selected at runtime.
This implementation is intended for selection from the circular buffers,
so if `last < first` the encoder will first select bits from
[first, input_width) and then from [0, last).
Attributes
----------
input_width : int
Width of the input signal
outputs_count : int
Number of outputs to generate at once.
input : Signal, in
Signal with 1 on `i`-th bit if `i` can be selected by encoder
first : Signal, in
Index of the first bit in the `input`. Inclusive.
last : Signal, out
Index of the last bit in the `input`. Exclusive.
outputs : list[Signal], out
Signals with selected indicies, sorted in ascending order,
if the number of ready signals is less than `outputs_count`
then valid signals are at the beginning of the list.
valids : list[Signal], out
One bit for each output signal, indicating whether the output is valid or not.
"""

def __init__(self, input_width: int, outputs_count: int):
self.input_width = input_width
self.outputs_count = outputs_count

self.input = Signal(self.input_width)
self.first = Signal(range(self.input_width))
self.last = Signal(range(self.input_width))
self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]

@staticmethod
def create(
m: Module,
input_width: int,
input: ValueLike,
first: ValueLike,
last: ValueLike,
outputs_count: int = 1,
name: Optional[str] = None,
) -> list[tuple[Signal, Signal]]:
"""Syntax sugar for creating RingMultiPriorityEncoder
This static method allows to use RingMultiPriorityEncoder in a more functional
way. Instead of creating the instance manually, connecting all the signals and
adding a submodule, you can call this function to do it automatically.
This function is equivalent to:
.. highlight:: python
.. code-block:: python
m.submodules += prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count)
m.d.comb += prio_encoder.input.eq(one_hot_singal)
m.d.comb += prio_encoder.first.eq(first)
m.d.comb += prio_encoder.last.eq(last)
idx = prio_encoder.outputs
valid = prio.encoder.valids
Parameters
----------
m: Module
Module to add the RingMultiPriorityEncoder to.
input_width : int
Width of the one hot signal.
input : ValueLike
The one hot signal to decode.
first : ValueLike
Index of the first bit in the `input`. Inclusive.
last : ValueLike
Index of the last bit in the `input`. Exclusive.
outputs_count : int
Number of different decoder outputs to generate at once. Default: 1.
name : Optional[str]
Name to use when adding RingMultiPriorityEncoder to submodules.
If None, it will be added as an anonymous submodule. The given name
can not be used in a submodule that has already been added. Default: None.
Returns
-------
return : list[tuple[Signal, Signal]]
Returns a list with len equal to outputs_count. Each tuple contains
a pair of decoded index on the first position and a valid signal
on the second position.
"""
prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count)
if name is None:
m.submodules += prio_encoder
else:
try:
getattr(m.submodules, name)
raise ValueError(
f"Name: {name} is already in use, so RingMultiPriorityEncoder can not be added with it."
)
except AttributeError:
setattr(m.submodules, name, prio_encoder)
m.d.comb += prio_encoder.input.eq(input)
m.d.comb += prio_encoder.first.eq(first)
m.d.comb += prio_encoder.last.eq(last)
return list(zip(prio_encoder.outputs, prio_encoder.valids))

@staticmethod
def create_simple(
m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None
) -> tuple[Signal, Signal]:
"""Syntax sugar for creating RingMultiPriorityEncoder
This is the same as `create` function, but with `outputs_count` hardcoded to 1.
"""
lst = RingMultiPriorityEncoder.create(m, input_width, input, first, last, outputs_count=1, name=name)
return lst[0]

def elaborate(self, platform):
m = Module()
double_input = Signal(2 * self.input_width)
m.d.comb += double_input.eq(Cat(self.input, self.input))

last_corrected = Signal(range(self.input_width * 2))
with m.If(self.first > self.last):
m.d.comb += last_corrected.eq(self.input_width + self.last)
with m.Else():
m.d.comb += last_corrected.eq(self.last)

mask = Signal.like(double_input)
m.d.comb += mask.eq((1 << last_corrected) - 1)

multi_enc_input = (double_input & mask) >> self.first

m.submodules.multi_enc = multi_enc = MultiPriorityEncoder(self.input_width, self.outputs_count)
m.d.comb += multi_enc.input.eq(multi_enc_input)
for k in range(self.outputs_count):
moved_out = Signal(range(2 * self.input_width))
m.d.comb += moved_out.eq(multi_enc.outputs[k] + self.first)
corrected_out = Mux(moved_out >= self.input_width, moved_out - self.input_width, moved_out)

m.d.comb += self.outputs[k].eq(corrected_out)
m.d.comb += self.valids[k].eq(multi_enc.valids[k])
return m

0 comments on commit 1ab2585

Please sign in to comment.