Skip to content

Commit

Permalink
CLZ and CTZ for non power of two inputs (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hazardu authored Dec 17, 2024
1 parent 9eb4d17 commit ccdb153
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 33 deletions.
36 changes: 21 additions & 15 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
count_leading_zeros,
count_trailing_zeros,
)
from amaranth.utils import ceil_log2


class TestAlignToPowerOfTwo(unittest.TestCase):
Expand Down Expand Up @@ -93,10 +94,10 @@ def test_popcount(self, size):


class CLZTestCircuit(Elaboratable):
def __init__(self, xlen_log: int):
self.sig_in = Signal(1 << xlen_log)
self.sig_out = Signal(xlen_log + 1)
self.xlen_log = xlen_log
def __init__(self, xlen: int):
self.sig_in = Signal(xlen)
self.sig_out = Signal(ceil_log2(xlen) + 1)
self.xlen = xlen

def elaborate(self, platform):
m = Module()
Expand All @@ -109,7 +110,7 @@ def elaborate(self, platform):
return m


@pytest.mark.parametrize("size", range(1, 7))
@pytest.mark.parametrize("size", [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 97, 98, 127, 128])
class TestCountLeadingZeros(TestCaseWithSimulator):
@pytest.fixture(scope="function", autouse=True)
def setup_fixture(self, size):
Expand All @@ -121,25 +122,28 @@ def setup_fixture(self, size):
def check(self, sim: TestbenchContext, n):
sim.set(self.m.sig_in, n)
out_clz = sim.get(self.m.sig_out)
assert out_clz == (2**self.size) - n.bit_length(), f"{n:x}"
expected = (self.size) - n.bit_length()
assert out_clz == expected, f"Incorrect result: got {out_clz}\t expected: {expected}"

async def process(self, sim: TestbenchContext):
for i in range(self.test_number):
n = random.randrange(2**self.size)
n = random.randrange(self.size)
self.check(sim, n)
sim.delay(1e-6)
self.check(sim, 2**self.size - 1)
await sim.delay(1e-6)
self.check(sim, 0)

def test_count_leading_zeros(self, size):
with self.run_simulation(self.m) as sim:
sim.add_testbench(self.process)


class CTZTestCircuit(Elaboratable):
def __init__(self, xlen_log: int):
self.sig_in = Signal(1 << xlen_log)
self.sig_out = Signal(xlen_log + 1)
self.xlen_log = xlen_log
def __init__(self, xlen: int):
self.sig_in = Signal(xlen)
self.sig_out = Signal(ceil_log2(xlen) + 1)
self.xlen = xlen

def elaborate(self, platform):
m = Module()
Expand All @@ -152,7 +156,7 @@ def elaborate(self, platform):
return m


@pytest.mark.parametrize("size", range(1, 7))
@pytest.mark.parametrize("size", [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 97, 98, 127, 128])
class TestCountTrailingZeros(TestCaseWithSimulator):
@pytest.fixture(scope="function", autouse=True)
def setup_fixture(self, size):
Expand All @@ -167,7 +171,7 @@ def check(self, sim: TestbenchContext, n):

expected = 0
if n == 0:
expected = 2**self.size
expected = self.size
else:
while (n & 1) == 0:
expected += 1
Expand All @@ -177,10 +181,12 @@ def check(self, sim: TestbenchContext, n):

async def process(self, sim: TestbenchContext):
for i in range(self.test_number):
n = random.randrange(2**self.size)
n = random.randrange(self.size)
self.check(sim, n)
await sim.delay(1e-6)
self.check(sim, 2**self.size - 1)
self.check(sim, self.size - 1)
await sim.delay(1e-6)
self.check(sim, 0)

def test_count_trailing_zeros(self, size):
with self.run_simulation(self.m) as sim:
Expand Down
5 changes: 2 additions & 3 deletions transactron/utils/amaranth_ext/coding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (C) 2019-2024 Amaranth HDL contributors

from amaranth import *
from transactron.utils.amaranth_ext.functions import count_leading_zeros


__all__ = [
Expand Down Expand Up @@ -84,9 +85,7 @@ def __init__(self, width: int):

def elaborate(self, platform):
m = Module()
for j in reversed(range(self.width)):
with m.If(self.i[j]):
m.d.comb += self.o.eq(j)
m.d.comb += self.o.eq(count_leading_zeros(self.i))
m.d.comb += self.n.eq(self.i == 0)
return m

Expand Down
22 changes: 7 additions & 15 deletions transactron/utils/amaranth_ext/functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any
from amaranth import *
from amaranth.hdl import ShapeCastable, ValueCastable
from amaranth.utils import bits_for, exact_log2
from amaranth.utils import bits_for, ceil_log2
from amaranth.lib import data
from collections.abc import Iterable, Mapping

Expand Down Expand Up @@ -59,28 +59,20 @@ def iter(s: Value, step: int) -> Value:

return result

try:
xlen_log = exact_log2(len(s))
except ValueError:
raise NotImplementedError("CountLeadingZeros - only sizes aligned to power of 2 are supperted")

value = iter(s, xlen_log)
slen = len(s)
slen_log = ceil_log2(slen)
closest_pow_2_of_s = 2**slen_log
zeros_prepend_count = closest_pow_2_of_s - slen
value = iter(Cat(C(0, shape=zeros_prepend_count), s), slen_log)

# 0 number edge case
# if s == 0 then iter() returns value off by 1
# this switch negates this effect
high_bit = 1 << xlen_log

result = Mux(s.any(), value, high_bit)
result = Mux(s.any(), value, slen)
return result


def count_trailing_zeros(s: Value) -> Value:
try:
exact_log2(len(s))
except ValueError:
raise NotImplementedError("CountTrailingZeros - only sizes aligned to power of 2 are supperted")

return count_leading_zeros(s[::-1])


Expand Down

0 comments on commit ccdb153

Please sign in to comment.