From ccdb153b74f232beebc885466f53ebdab4a9a534 Mon Sep 17 00:00:00 2001 From: "Hazard (Cyprian Skrzypczak)" Date: Tue, 17 Dec 2024 11:55:34 +0100 Subject: [PATCH] CLZ and CTZ for non power of two inputs (#29) --- test/utils/test_utils.py | 36 ++++++++++++--------- transactron/utils/amaranth_ext/coding.py | 5 ++- transactron/utils/amaranth_ext/functions.py | 22 ++++--------- 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 3e29d8a..e403a02 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -11,6 +11,7 @@ count_leading_zeros, count_trailing_zeros, ) +from amaranth.utils import ceil_log2 class TestAlignToPowerOfTwo(unittest.TestCase): @@ -93,10 +94,10 @@ def test_popcount(self, size): class CLZTestCircuit(Elaboratable): - def __init__(self, xlen_log: int): - self.sig_in = Signal(1 << xlen_log) - self.sig_out = Signal(xlen_log + 1) - self.xlen_log = xlen_log + def __init__(self, xlen: int): + self.sig_in = Signal(xlen) + self.sig_out = Signal(ceil_log2(xlen) + 1) + self.xlen = xlen def elaborate(self, platform): m = Module() @@ -109,7 +110,7 @@ def elaborate(self, platform): return m -@pytest.mark.parametrize("size", range(1, 7)) +@pytest.mark.parametrize("size", [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 97, 98, 127, 128]) class TestCountLeadingZeros(TestCaseWithSimulator): @pytest.fixture(scope="function", autouse=True) def setup_fixture(self, size): @@ -121,14 +122,17 @@ def setup_fixture(self, size): def check(self, sim: TestbenchContext, n): sim.set(self.m.sig_in, n) out_clz = sim.get(self.m.sig_out) - assert out_clz == (2**self.size) - n.bit_length(), f"{n:x}" + expected = (self.size) - n.bit_length() + assert out_clz == expected, f"Incorrect result: got {out_clz}\t expected: {expected}" async def process(self, sim: TestbenchContext): for i in range(self.test_number): - n = random.randrange(2**self.size) + n = random.randrange(self.size) self.check(sim, n) sim.delay(1e-6) self.check(sim, 2**self.size - 1) + await sim.delay(1e-6) + self.check(sim, 0) def test_count_leading_zeros(self, size): with self.run_simulation(self.m) as sim: @@ -136,10 +140,10 @@ def test_count_leading_zeros(self, size): class CTZTestCircuit(Elaboratable): - def __init__(self, xlen_log: int): - self.sig_in = Signal(1 << xlen_log) - self.sig_out = Signal(xlen_log + 1) - self.xlen_log = xlen_log + def __init__(self, xlen: int): + self.sig_in = Signal(xlen) + self.sig_out = Signal(ceil_log2(xlen) + 1) + self.xlen = xlen def elaborate(self, platform): m = Module() @@ -152,7 +156,7 @@ def elaborate(self, platform): return m -@pytest.mark.parametrize("size", range(1, 7)) +@pytest.mark.parametrize("size", [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 97, 98, 127, 128]) class TestCountTrailingZeros(TestCaseWithSimulator): @pytest.fixture(scope="function", autouse=True) def setup_fixture(self, size): @@ -167,7 +171,7 @@ def check(self, sim: TestbenchContext, n): expected = 0 if n == 0: - expected = 2**self.size + expected = self.size else: while (n & 1) == 0: expected += 1 @@ -177,10 +181,12 @@ def check(self, sim: TestbenchContext, n): async def process(self, sim: TestbenchContext): for i in range(self.test_number): - n = random.randrange(2**self.size) + n = random.randrange(self.size) self.check(sim, n) await sim.delay(1e-6) - self.check(sim, 2**self.size - 1) + self.check(sim, self.size - 1) + await sim.delay(1e-6) + self.check(sim, 0) def test_count_trailing_zeros(self, size): with self.run_simulation(self.m) as sim: diff --git a/transactron/utils/amaranth_ext/coding.py b/transactron/utils/amaranth_ext/coding.py index 5360579..392d2d6 100644 --- a/transactron/utils/amaranth_ext/coding.py +++ b/transactron/utils/amaranth_ext/coding.py @@ -2,6 +2,7 @@ # Copyright (C) 2019-2024 Amaranth HDL contributors from amaranth import * +from transactron.utils.amaranth_ext.functions import count_leading_zeros __all__ = [ @@ -84,9 +85,7 @@ def __init__(self, width: int): def elaborate(self, platform): m = Module() - for j in reversed(range(self.width)): - with m.If(self.i[j]): - m.d.comb += self.o.eq(j) + m.d.comb += self.o.eq(count_leading_zeros(self.i)) m.d.comb += self.n.eq(self.i == 0) return m diff --git a/transactron/utils/amaranth_ext/functions.py b/transactron/utils/amaranth_ext/functions.py index d4e634a..d046240 100644 --- a/transactron/utils/amaranth_ext/functions.py +++ b/transactron/utils/amaranth_ext/functions.py @@ -1,7 +1,7 @@ from typing import Any from amaranth import * from amaranth.hdl import ShapeCastable, ValueCastable -from amaranth.utils import bits_for, exact_log2 +from amaranth.utils import bits_for, ceil_log2 from amaranth.lib import data from collections.abc import Iterable, Mapping @@ -59,28 +59,20 @@ def iter(s: Value, step: int) -> Value: return result - try: - xlen_log = exact_log2(len(s)) - except ValueError: - raise NotImplementedError("CountLeadingZeros - only sizes aligned to power of 2 are supperted") - - value = iter(s, xlen_log) + slen = len(s) + slen_log = ceil_log2(slen) + closest_pow_2_of_s = 2**slen_log + zeros_prepend_count = closest_pow_2_of_s - slen + value = iter(Cat(C(0, shape=zeros_prepend_count), s), slen_log) # 0 number edge case # if s == 0 then iter() returns value off by 1 # this switch negates this effect - high_bit = 1 << xlen_log - - result = Mux(s.any(), value, high_bit) + result = Mux(s.any(), value, slen) return result def count_trailing_zeros(s: Value) -> Value: - try: - exact_log2(len(s)) - except ValueError: - raise NotImplementedError("CountTrailingZeros - only sizes aligned to power of 2 are supperted") - return count_leading_zeros(s[::-1])