diff --git a/test/transactron/utils/test_amaranth_ext.py b/test/transactron/utils/test_amaranth_ext.py index fa291567b..7943ccb76 100644 --- a/test/transactron/utils/test_amaranth_ext.py +++ b/test/transactron/utils/test_amaranth_ext.py @@ -3,140 +3,73 @@ from transactron.utils.amaranth_ext import MultiPriorityEncoder, RingMultiPriorityEncoder -class TestMultiPriorityEncoder(TestCaseWithSimulator): - def get_expected(self, input): - places = [] - for i in range(self.input_width): - if input % 2: - places.append(i) - input //= 2 - places += [None] * self.output_count - return places - - def process(self): - for _ in range(self.test_number): - input = random.randrange(2**self.input_width) - yield self.circ.input.eq(input) - yield Settle() - expected_output = self.get_expected(input) - for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): - if ex is None: - assert (yield valid) == 0 - else: - assert (yield valid) == 1 - assert (yield real) == ex - yield Delay(1e-7) +def get_expected_multi(input_width, output_count, input, *args): + places = [] + for i in range(input_width): + if input % 2: + places.append(i) + input //= 2 + places += [None] * output_count + return places + + +def get_expected_ring(input_width, output_count, input, first, last): + places = [] + input = (input << input_width) + input + if last < first: + last += input_width + for i in range(2 * input_width): + if i >= first and i < last and input % 2: + places.append(i % input_width) + input //= 2 + places += [None] * output_count + return places + + +@pytest.mark.parametrize( + "test_class, verif_f", + [(MultiPriorityEncoder, get_expected_multi), (RingMultiPriorityEncoder, get_expected_ring)], + ids=["MultiPriorityEncoder", "RingMultiPriorityEncoder"], +) +class TestPriorityEncoder(TestCaseWithSimulator): + def process(self, get_expected): + def f(): + for _ in range(self.test_number): + input = random.randrange(2**self.input_width) + first = random.randrange(self.input_width) + last = random.randrange(self.input_width) + yield self.circ.input.eq(input) + try: + yield self.circ.first.eq(first) + yield self.circ.last.eq(last) + except AttributeError: + pass + yield Settle() + expected_output = get_expected(self.input_width, self.output_count, input, first, last) + for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): + if ex is None: + assert (yield valid) == 0 + else: + assert (yield valid) == 1 + assert (yield real) == ex + yield Delay(1e-7) + + return f @pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24]) @pytest.mark.parametrize("output_count", [1, 3, 4]) - def test_random(self, input_width, output_count): + def test_random(self, test_class, verif_f, input_width, output_count): random.seed(input_width + output_count) self.test_number = 50 self.input_width = input_width self.output_count = output_count - self.circ = MultiPriorityEncoder(self.input_width, self.output_count) + self.circ = test_class(self.input_width, self.output_count) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) + sim.add_process(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create_simple(self, name): - random.seed(14) - self.test_number = 50 - self.input_width = 7 - self.output_count = 1 - - class DUT(Elaboratable): - def __init__(self, input_width, output_count, name): - self.input = Signal(input_width) - self.output_count = output_count - self.input_width = input_width - self.name = name - - def elaborate(self, platform): - m = Module() - out, val = MultiPriorityEncoder.create_simple(m, self.input_width, self.input, name=self.name) - # Save as a list to use common interface in testing - self.outputs = [out] - self.valids = [val] - return m - - self.circ = DUT(self.input_width, self.output_count, name) - - with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) - - @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create(self, name): - random.seed(14) - self.test_number = 50 - self.input_width = 7 - self.output_count = 2 - - class DUT(Elaboratable): - def __init__(self, input_width, output_count, name): - self.input = Signal(input_width) - self.output_count = output_count - self.input_width = input_width - self.name = name - - def elaborate(self, platform): - m = Module() - out = MultiPriorityEncoder.create(m, self.input_width, self.input, self.output_count, name=self.name) - self.outputs, self.valids = list(zip(*out)) - return m - - self.circ = DUT(self.input_width, self.output_count, name) - - with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) - - -class TestRingMultiPriorityEncoder(TestCaseWithSimulator): - def get_expected(self, input, first, last): - places = [] - input = (input << self.input_width) + input - if last < first: - last += self.input_width - for i in range(2 * self.input_width): - if i >= first and i < last and input % 2: - places.append(i % self.input_width) - input //= 2 - places += [None] * self.output_count - return places - - def process(self): - for _ in range(self.test_number): - input = random.randrange(2**self.input_width) - first = random.randrange(self.input_width) - last = random.randrange(self.input_width) - yield self.circ.input.eq(input) - yield self.circ.first.eq(first) - yield self.circ.last.eq(last) - yield Settle() - expected_output = self.get_expected(input, first, last) - for ex, real, valid in zip(expected_output, self.circ.outputs, self.circ.valids): - if ex is None: - assert (yield valid) == 0 - else: - assert (yield valid) == 1 - assert (yield real) == ex - yield Delay(1e-7) - - @pytest.mark.parametrize("input_width", [1, 5, 16, 23, 24]) - @pytest.mark.parametrize("output_count", [1, 3, 4]) - def test_random(self, input_width, output_count): - random.seed(input_width + output_count) - self.test_number = 50 - self.input_width = input_width - self.output_count = output_count - self.circ = RingMultiPriorityEncoder(self.input_width, self.output_count) - - with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) - - @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create_simple(self, name): + def test_static_create_simple(self, test_class, verif_f, name): random.seed(14) self.test_number = 50 self.input_width = 7 @@ -153,9 +86,12 @@ def __init__(self, input_width, output_count, name): def elaborate(self, platform): m = Module() - out, val = RingMultiPriorityEncoder.create_simple( - m, self.input_width, self.input, self.first, self.last, name=self.name - ) + if test_class == MultiPriorityEncoder: + out, val = test_class.create_simple(m, self.input_width, self.input, name=self.name) + else: + out, val = test_class.create_simple( + m, self.input_width, self.input, self.first, self.last, name=self.name + ) # Save as a list to use common interface in testing self.outputs = [out] self.valids = [val] @@ -164,10 +100,10 @@ def elaborate(self, platform): self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) + sim.add_process(self.process(verif_f)) @pytest.mark.parametrize("name", ["prio_encoder", None]) - def test_static_create(self, name): + def test_static_create(self, test_class, verif_f, name): random.seed(14) self.test_number = 50 self.input_width = 7 @@ -184,13 +120,16 @@ def __init__(self, input_width, output_count, name): def elaborate(self, platform): m = Module() - out = RingMultiPriorityEncoder.create( - m, self.input_width, self.input, self.first, self.last, self.output_count, name=self.name - ) + if test_class == MultiPriorityEncoder: + out = test_class.create(m, self.input_width, self.input, self.output_count, name=self.name) + else: + out = test_class.create( + m, self.input_width, self.input, self.first, self.last, self.output_count, name=self.name + ) self.outputs, self.valids = list(zip(*out)) return m self.circ = DUT(self.input_width, self.output_count, name) with self.run_simulation(self.circ) as sim: - sim.add_process(self.process) + sim.add_process(self.process(verif_f))