diff --git a/test/utils/test_amaranth_ext.py b/test/utils/test_amaranth_ext.py index 7b7bd46..7943ccb 100644 --- a/test/utils/test_amaranth_ext.py +++ b/test/utils/test_amaranth_ext.py @@ -1,46 +1,75 @@ from transactron.testing import * import random -from transactron.utils.amaranth_ext import MultiPriorityEncoder - - -class TestMultiPriorityEncoder(TestCaseWithSimulator): - def get_expected(self, input): - places = [] - for i in range(self.input_width): - if input % 2: - places.append(i) - input //= 2 - places += [None] * self.output_count - return places - - def process(self): - for _ in range(self.test_number): - input = random.randrange(2**self.input_width) - yield self.circ.input.eq(input) - yield Settle() - expected_output = self.get_expected(input) - for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): - if ex is None: - assert (yield valid) == 0 - else: - assert (yield valid) == 1 - assert (yield real) == ex - yield Delay(1e-7) +from transactron.utils.amaranth_ext import MultiPriorityEncoder, RingMultiPriorityEncoder + + +def get_expected_multi(input_width, output_count, input, *args): + places = [] + for i in range(input_width): + if input % 2: + places.append(i) + input //= 2 + places += [None] * output_count + return places + + +def get_expected_ring(input_width, output_count, input, first, last): + places = [] + input = (input << input_width) + input + if last < first: + last += input_width + for i in range(2 * input_width): + if i >= first and i < last and input % 2: + places.append(i % input_width) + input //= 2 + places += [None] * output_count + return places + + +@pytest.mark.parametrize( + "test_class, verif_f", + [(MultiPriorityEncoder, get_expected_multi), (RingMultiPriorityEncoder, get_expected_ring)], + ids=["MultiPriorityEncoder", "RingMultiPriorityEncoder"], +) +class TestPriorityEncoder(TestCaseWithSimulator): + def process(self, get_expected): + def f(): + for _ in range(self.test_number): + input = random.randrange(2**self.input_width) + first = random.randrange(self.input_width) + last = random.randrange(self.input_width) + yield self.circ.input.eq(input) + try: + yield self.circ.first.eq(first) + yield self.circ.last.eq(last) + except AttributeError: + pass + yield Settle() + expected_output = get_expected(self.input_width, self.output_count, input, first, last) + for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): + if ex is None: + assert (yield valid) == 0 + else: + assert (yield valid) == 1 + assert (yield real) == ex + yield Delay(1e-7) + + return f @pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24]) @pytest.mark.parametrize("output_count", [1, 3, 4]) - def test_random(self, input_width, output_count): + def test_random(self, test_class, verif_f, input_width, output_count): random.seed(input_width + output_count) self.test_number = 50 self.input_width = input_width self.output_count = output_count - self.circ = MultiPriorityEncoder(self.input_width, self.output_count) + self.circ = test_class(self.input_width, self.output_count) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) + sim.add_process(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create_simple(self, name): + def test_static_create_simple(self, test_class, verif_f, name): random.seed(14) self.test_number = 50 self.input_width = 7 @@ -49,13 +78,20 @@ def test_static_create_simple(self, name): class DUT(Elaboratable): def __init__(self, input_width, output_count, name): self.input = Signal(input_width) + self.first = Signal(range(input_width)) + self.last = Signal(range(input_width)) self.output_count = output_count self.input_width = input_width self.name = name def elaborate(self, platform): m = Module() - out, val = MultiPriorityEncoder.create_simple(m, self.input_width, self.input, name=self.name) + if test_class == MultiPriorityEncoder: + out, val = test_class.create_simple(m, self.input_width, self.input, name=self.name) + else: + out, val = test_class.create_simple( + m, self.input_width, self.input, self.first, self.last, name=self.name + ) # Save as a list to use common interface in testing self.outputs = [out] self.valids = [val] @@ -64,10 +100,10 @@ def elaborate(self, platform): self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) + sim.add_process(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create(self, name): + def test_static_create(self, test_class, verif_f, name): random.seed(14) self.test_number = 50 self.input_width = 7 @@ -76,17 +112,24 @@ def test_static_create(self, name): class DUT(Elaboratable): def __init__(self, input_width, output_count, name): self.input = Signal(input_width) + self.first = Signal(range(input_width)) + self.last = Signal(range(input_width)) self.output_count = output_count self.input_width = input_width self.name = name def elaborate(self, platform): m = Module() - out = MultiPriorityEncoder.create(m, self.input_width, self.input, self.output_count, name=self.name) + if test_class == MultiPriorityEncoder: + out = test_class.create(m, self.input_width, self.input, self.output_count, name=self.name) + else: + out = test_class.create( + m, self.input_width, self.input, self.first, self.last, self.output_count, name=self.name + ) self.outputs, self.valids = list(zip(*out)) return m self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) + sim.add_process(self.process(verif_f)) diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index 6048bc7..7feed52 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -12,6 +12,7 @@ "Scheduler", "RoundRobin", "MultiPriorityEncoder", + "RingMultiPriorityEncoder", ] @@ -377,3 +378,155 @@ def elaborate(self, platform): m.d.comb += self.valids[k].eq(level_valids[k]) return m + + +class RingMultiPriorityEncoder(Elaboratable): + """Priority encoder with one or more outputs and flexible start + + This is an extension of the `MultiPriorityEncoder` that supports + flexible start and end indexes. In the standard `MultiPriorityEncoder` + the first bit is always at position 0 and the last is the last bit of + the input signal. In this extended implementation, both can be + selected at runtime. + + This implementation is intended for selection from the circular buffers, + so if `last < first` the encoder will first select bits from + [first, input_width) and then from [0, last). + + Attributes + ---------- + input_width : int + Width of the input signal + outputs_count : int + Number of outputs to generate at once. + input : Signal, in + Signal with 1 on `i`-th bit if `i` can be selected by encoder + first : Signal, in + Index of the first bit in the `input`. Inclusive. + last : Signal, out + Index of the last bit in the `input`. Exclusive. + outputs : list[Signal], out + Signals with selected indicies, sorted in ascending order, + if the number of ready signals is less than `outputs_count` + then valid signals are at the beginning of the list. + valids : list[Signal], out + One bit for each output signal, indicating whether the output is valid or not. + """ + + def __init__(self, input_width: int, outputs_count: int): + self.input_width = input_width + self.outputs_count = outputs_count + + self.input = Signal(self.input_width) + self.first = Signal(range(self.input_width)) + self.last = Signal(range(self.input_width)) + self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)] + self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)] + + @staticmethod + def create( + m: Module, + input_width: int, + input: ValueLike, + first: ValueLike, + last: ValueLike, + outputs_count: int = 1, + name: Optional[str] = None, + ) -> list[tuple[Signal, Signal]]: + """Syntax sugar for creating RingMultiPriorityEncoder + + This static method allows to use RingMultiPriorityEncoder in a more functional + way. Instead of creating the instance manually, connecting all the signals and + adding a submodule, you can call this function to do it automatically. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + m.submodules += prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count) + m.d.comb += prio_encoder.input.eq(one_hot_singal) + m.d.comb += prio_encoder.first.eq(first) + m.d.comb += prio_encoder.last.eq(last) + idx = prio_encoder.outputs + valid = prio.encoder.valids + + Parameters + ---------- + m: Module + Module to add the RingMultiPriorityEncoder to. + input_width : int + Width of the one hot signal. + input : ValueLike + The one hot signal to decode. + first : ValueLike + Index of the first bit in the `input`. Inclusive. + last : ValueLike + Index of the last bit in the `input`. Exclusive. + outputs_count : int + Number of different decoder outputs to generate at once. Default: 1. + name : Optional[str] + Name to use when adding RingMultiPriorityEncoder to submodules. + If None, it will be added as an anonymous submodule. The given name + can not be used in a submodule that has already been added. Default: None. + + Returns + ------- + return : list[tuple[Signal, Signal]] + Returns a list with len equal to outputs_count. Each tuple contains + a pair of decoded index on the first position and a valid signal + on the second position. + """ + prio_encoder = RingMultiPriorityEncoder(input_width, outputs_count) + if name is None: + m.submodules += prio_encoder + else: + try: + getattr(m.submodules, name) + raise ValueError( + f"Name: {name} is already in use, so RingMultiPriorityEncoder can not be added with it." + ) + except AttributeError: + setattr(m.submodules, name, prio_encoder) + m.d.comb += prio_encoder.input.eq(input) + m.d.comb += prio_encoder.first.eq(first) + m.d.comb += prio_encoder.last.eq(last) + return list(zip(prio_encoder.outputs, prio_encoder.valids)) + + @staticmethod + def create_simple( + m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None + ) -> tuple[Signal, Signal]: + """Syntax sugar for creating RingMultiPriorityEncoder + + This is the same as `create` function, but with `outputs_count` hardcoded to 1. + """ + lst = RingMultiPriorityEncoder.create(m, input_width, input, first, last, outputs_count=1, name=name) + return lst[0] + + def elaborate(self, platform): + m = Module() + double_input = Signal(2 * self.input_width) + m.d.comb += double_input.eq(Cat(self.input, self.input)) + + last_corrected = Signal(range(self.input_width * 2)) + with m.If(self.first > self.last): + m.d.comb += last_corrected.eq(self.input_width + self.last) + with m.Else(): + m.d.comb += last_corrected.eq(self.last) + + mask = Signal.like(double_input) + m.d.comb += mask.eq((1 << last_corrected) - 1) + + multi_enc_input = (double_input & mask) >> self.first + + m.submodules.multi_enc = multi_enc = MultiPriorityEncoder(self.input_width, self.outputs_count) + m.d.comb += multi_enc.input.eq(multi_enc_input) + for k in range(self.outputs_count): + moved_out = Signal(range(2 * self.input_width)) + m.d.comb += moved_out.eq(multi_enc.outputs[k] + self.first) + corrected_out = Mux(moved_out >= self.input_width, moved_out - self.input_width, moved_out) + + m.d.comb += self.outputs[k].eq(corrected_out) + m.d.comb += self.valids[k].eq(multi_enc.valids[k]) + return m