diff --git a/coreblocks/func_blocks/fu/fpu/lza.py b/coreblocks/func_blocks/fu/fpu/lza.py index 00709ad80..3e9695886 100644 --- a/coreblocks/func_blocks/fu/fpu/lza.py +++ b/coreblocks/func_blocks/fu/fpu/lza.py @@ -12,6 +12,13 @@ class LZAMethodLayout: """ def __init__(self, *, fpu_params: FPUParams): + """ + sig_a - significand of a + sig_b - significand of b + carry - indicates if we want to predict result of a+b or a+b+1 + shift_amount - position to shift needed to normalize number + is_zero - indicates if result is zero + """ self.predict_in_layout = [ ("sig_a", fpu_params.sig_width), ("sig_b", fpu_params.sig_width), @@ -25,6 +32,19 @@ def __init__(self, *, fpu_params: FPUParams): class LZAModule(Elaboratable): """LZA module + Based on: https://userpages.cs.umbc.edu/phatak/645/supl/lza/lza-survey-arith01.pdf + After performing subtracion we may have to normalize floating point number and + for that we have to know the number of leading zeros. + Most basic approach includes using LZC (leading zero counter) after subtracion. + More advanced approach includes using LZA (Leading Zero Anticipator) to predict number + leading zeroes. It is worth noting that this LZA module works under assumptions that + significands are in two's complement and that before complementation sig_a was greater + or equal to sig_b. Another think worth noting is that LZA works with error = 1. + That means that if 'n' is the result of LZA module the in reality to normalize + number we may have to shift left by 'n' or 'n+1'. There are few techniques of + dealing with that error like specialy designed shifters or predicting the error + but most basic approach is to just use multiplexer after shifter to perform + one more shift left if necessary. Parameters ---------- @@ -52,7 +72,7 @@ def elaborate(self, platform): m = TModule() @def_method(m, self.predict_request) - def _(arg): + def _(sig_a, sig_b, carry): t = Signal(self.lza_params.sig_width + 1) g = Signal(self.lza_params.sig_width + 1) z = Signal(self.lza_params.sig_width + 1) @@ -60,23 +80,22 @@ def _(arg): shift_amount = Signal(range(self.lza_params.sig_width)) is_zero = Signal(1) - m.d.av_comb += t.eq((arg.sig_a ^ arg.sig_b) << 1) - m.d.av_comb += g.eq((arg.sig_a & arg.sig_b) << 1) - m.d.av_comb += z.eq(((~(arg.sig_a) & ~(arg.sig_b)) << 1)) - with m.If(arg.carry): + m.d.av_comb += t.eq((sig_a ^ sig_b) << 1) + m.d.av_comb += g.eq((sig_a & sig_b) << 1) + m.d.av_comb += z.eq(((sig_a | sig_b) << 1)) + with m.If(carry): m.d.av_comb += g[0].eq(1) - with m.Else(): m.d.av_comb += z[0].eq(1) for i in reversed(range(1, self.lza_params.sig_width + 1)): - m.d.av_comb += f[i - 1].eq((t[i] ^ ~(z[i - 1]))) + m.d.av_comb += f[i - 1].eq((t[i] ^ z[i - 1])) m.d.av_comb += shift_amount.eq(0) for i in reversed(range(self.lza_params.sig_width)): with m.If(f[self.lza_params.sig_width - i - 1]): m.d.av_comb += shift_amount.eq(i) - m.d.av_comb += is_zero.eq((arg.carry & t[1 : self.lza_params.sig_width].all())) + m.d.av_comb += is_zero.eq((carry & t[1 : self.lza_params.sig_width].all())) return { "shift_amount": shift_amount, diff --git a/test/func_blocks/fu/test_lza.py b/test/func_blocks/fu/test_lza.py index f77ab9563..522684ecd 100644 --- a/test/func_blocks/fu/test_lza.py +++ b/test/func_blocks/fu/test_lza.py @@ -1,5 +1,6 @@ from coreblocks.func_blocks.fu.fpu.lza import * from coreblocks.func_blocks.fu.fpu.fpu_common import FPUParams +from random import randint from transactron import TModule from transactron.lib import AdapterTrans from transactron.testing import * @@ -37,6 +38,27 @@ def test_manual(self): help_values = TestLZA.HelpValues(params) lza = TestLZA.LZAModuleTest(params) + def clz(sig_a, sig_b, carry): + zeros = 0 + msb_bit_mask = 1 << (params.sig_width - 1) + sum = sig_a + sig_b + carry + while 1: + if not (sum & msb_bit_mask): + zeros += 1 + sum = sum << 1 + else: + return zeros + + def random_test(): + xor_mask = (2**params.sig_width) - 1 + sig_a = randint(1 << (params.sig_width - 1), (2**params.sig_width) - 1) + sig_b = randint(1 << (params.sig_width - 1), sig_a) + sig_b = (sig_b ^ xor_mask) | (1 << params.sig_width) + resp = yield from lza.predict_request_adapter.call({"sig_a": sig_a, "sig_b": sig_b, "carry": 0}) + pred_lz = resp["shift_amount"] + true_lz = clz(sig_a, sig_b, 0) + assert pred_lz == true_lz or (pred_lz + 1) == true_lz + def lza_test(): test_cases = [ { @@ -121,6 +143,7 @@ def lza_test(): def test_process(): yield from lza_test() + yield from random_test() with self.run_simulation(lza) as sim: - sim.add_sync_process(test_process) + sim.add_process(test_process)