Skip to content

Commit

Permalink
LZA (Leading zeros anticipation) (kuznia-rdzeni#741)
Browse files Browse the repository at this point in the history
  • Loading branch information
Durchbruchswagen authored and tilk committed Dec 16, 2024
1 parent 576adff commit df115b7
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 0 deletions.
111 changes: 111 additions & 0 deletions coreblocks/func_blocks/fu/fpu/lza.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from amaranth import *
from amaranth.utils import ceil_log2
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import FPUParams
from transactron.utils.amaranth_ext import count_leading_zeros


class LZAMethodLayout:
"""LZA module layouts for methods
Parameters
----------
fpu_params: FPUParams
FPU parameters
"""

def __init__(self, *, fpu_params: FPUParams):
"""
sig_a - significand of a
sig_b - significand of b
carry - indicates if we want to predict result of a+b or a+b+1
shift_amount - position to shift needed to normalize number
is_zero - indicates if result is zero
"""
self.predict_in_layout = [
("sig_a", fpu_params.sig_width),
("sig_b", fpu_params.sig_width),
("carry", 1),
]
self.predict_out_layout = [
("shift_amount", range(fpu_params.sig_width)),
("is_zero", 1),
]


class LZAModule(Elaboratable):
"""LZA module
Based on: https://userpages.cs.umbc.edu/phatak/645/supl/lza/lza-survey-arith01.pdf
After performing subtracion, we may have to normalize floating point numbers and
For that, we have to know the number of leading zeros.
The most basic approach includes using LZC (leading zero counter) after subtracion,
a more advanced approach includes using LZA (Leading Zero Anticipator) to predict the number of
leading zeroes. It is worth noting that this LZA module works under assumptions that
significands are in two's complement and that before complementation sig_a was greater
or equal to sig_b. Another thing worth noting is that LZA works with error = 1.
That means that if 'n' is the result of the LZA module, in reality, to normalize
number we may have to shift left by 'n' or 'n+1'. There are few techniques of
dealing with that error like specially designed shifters or predicting the error
but the most basic approach is to just use multiplexer after shifter to perform
one more shift left if necessary.
Parameters
----------
fpu_params: FPUParams
FPU rounding module parameters
Attributes
----------
predict_request: Method
Transactional method for initiating leading zeros prediction.
Takes 'predict_in_layout' as argument
Returns shift amount as 'predict_out_layout'
"""

def __init__(self, *, fpu_params: FPUParams):

self.lza_params = fpu_params
self.method_layouts = LZAMethodLayout(fpu_params=self.lza_params)
self.predict_request = Method(
i=self.method_layouts.predict_in_layout,
o=self.method_layouts.predict_out_layout,
)

def elaborate(self, platform):
m = TModule()

@def_method(m, self.predict_request)
def _(sig_a, sig_b, carry):
f_size = 2 ** ceil_log2(self.lza_params.sig_width)
filler_size = f_size - self.lza_params.sig_width
lower_ones = Const((2**filler_size) - 1, f_size)

t = Signal(self.lza_params.sig_width + 1)
g = Signal(self.lza_params.sig_width + 1)
z = Signal(self.lza_params.sig_width + 1)
f = Signal(f_size)
shift_amount = Signal(range(self.lza_params.sig_width))
is_zero = Signal(1)

m.d.av_comb += t.eq((sig_a ^ sig_b) << 1)
m.d.av_comb += g.eq((sig_a & sig_b) << 1)
m.d.av_comb += z.eq(((sig_a | sig_b) << 1))
with m.If(carry):
m.d.av_comb += g[0].eq(1)
m.d.av_comb += z[0].eq(1)

for i in reversed(range(1, self.lza_params.sig_width + 1)):
m.d.av_comb += f[i + filler_size - 1].eq((t[i] ^ z[i - 1]))

m.d.av_comb += shift_amount.eq(0)
m.d.av_comp += f.eq(f | lower_ones)
m.d.av_comb += shift_amount.eq(count_leading_zeros(f))

m.d.av_comb += is_zero.eq((carry & t[1 : self.lza_params.sig_width].all()))

return {
"shift_amount": shift_amount,
"is_zero": is_zero,
}

return m
112 changes: 112 additions & 0 deletions test/func_blocks/fu/fpu/test_lza.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import random
from coreblocks.func_blocks.fu.fpu.lza import *
from coreblocks.func_blocks.fu.fpu.fpu_common import FPUParams
from transactron import TModule
from transactron.lib import AdapterTrans
from transactron.testing import *
from amaranth import *


def clz(sig_a, sig_b, carry, sig_width):
zeros = 0
msb_bit_mask = 1 << (sig_width - 1)
sum = sig_a + sig_b + carry
while 1:
if not (sum & msb_bit_mask):
zeros += 1
sum = sum << 1
else:
return zeros


class TestLZA(TestCaseWithSimulator):
class LZAModuleTest(Elaboratable):
def __init__(self, params: FPUParams):
self.params = params

def elaborate(self, platform):
m = TModule()
m.submodules.lza = lza = self.lza_module = LZAModule(fpu_params=self.params)
m.submodules.predict = self.predict_request_adapter = TestbenchIO(AdapterTrans(lza.predict_request))
return m

def test_manual(self):
params = FPUParams(sig_width=24, exp_width=8)
lza = TestLZA.LZAModuleTest(params)

async def random_test(sim: TestbenchContext, seed: int, iters: int):
xor_mask = (2**params.sig_width) - 1
random.seed(seed)
for _ in range(iters):
sig_a = random.randint(1 << (params.sig_width - 1), (2**params.sig_width) - 1)
sig_b = random.randint(1 << (params.sig_width - 1), sig_a)
sig_b = (sig_b ^ xor_mask) | (1 << params.sig_width)
resp = await lza.predict_request_adapter.call(sim, {"sig_a": sig_a, "sig_b": sig_b, "carry": 0})
pred_lz = resp["shift_amount"]
true_lz = clz(sig_a, sig_b, 0, params.sig_width)
assert pred_lz == true_lz or (pred_lz + 1) == true_lz

async def lza_test(sim: TestbenchContext):
test_cases = [
{
"sig_a": 16368512,
"sig_b": 409600,
"carry": 0,
},
{
"sig_a": 0,
"sig_b": (2**24) - 1,
"carry": 0,
},
{
"sig_a": (2**24) // 2,
"sig_b": (2**24) // 2,
"carry": 0,
},
{
"sig_a": 12582912,
"sig_b": 12550144,
"carry": 0,
},
{
"sig_a": 16744448,
"sig_b": 12615680,
"carry": 0,
},
{
"sig_a": 8421376,
"sig_b": 8421376,
"carry": 0,
},
]
expected_results = [
{"shift_amount": 13, "is_zero": 0},
{"shift_amount": 13, "is_zero": 0},
{"shift_amount": 23, "is_zero": 0},
{"shift_amount": 0, "is_zero": 1},
{"shift_amount": 0, "is_zero": 0},
{"shift_amount": 23, "is_zero": 0},
{"shift_amount": 0, "is_zero": 0},
{"shift_amount": 0, "is_zero": 0},
{"shift_amount": 0, "is_zero": 0},
{"shift_amount": 0, "is_zero": 0},
{"shift_amount": 7, "is_zero": 0},
{"shift_amount": 7, "is_zero": 0},
]
for i in range(len(test_cases)):

resp = await lza.predict_request_adapter.call(sim, test_cases[i])
assert resp["shift_amount"] == expected_results[2 * i]["shift_amount"]
assert resp["is_zero"] == expected_results[2 * i]["is_zero"]

test_cases[i]["carry"] = 1
resp = await lza.predict_request_adapter.call(sim, test_cases[i])
assert resp["shift_amount"] == expected_results[2 * i + 1]["shift_amount"]
assert resp["is_zero"] == expected_results[2 * i + 1]["is_zero"]

async def test_process(sim: TestbenchContext):
await lza_test(sim)
await random_test(sim, 2024, 20)

with self.run_simulation(lza) as sim:
sim.add_testbench(test_process)

0 comments on commit df115b7

Please sign in to comment.