Skip to content

Commit

Permalink
FPU rounding module (#728)
Browse files Browse the repository at this point in the history
  • Loading branch information
Durchbruchswagen authored Nov 12, 2024
1 parent def471b commit 34565a3
Show file tree
Hide file tree
Showing 6 changed files with 912 additions and 0 deletions.
Empty file.
38 changes: 38 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from amaranth.lib import enum


class RoundingModes(enum.Enum):
ROUND_UP = 3
ROUND_DOWN = 2
ROUND_ZERO = 1
ROUND_NEAREST_EVEN = 0
ROUND_NEAREST_AWAY = 4


class Errors(enum.IntFlag):
INVALID_OPERATION = enum.auto()
DIVISION_BY_ZERO = enum.auto()
OVERFLOW = enum.auto()
UNDERFLOW = enum.auto()
INEXACT = enum.auto()


class FPUParams:
"""FPU parameters
Parameters
----------
sig_width: int
Width of significand, including implicit bit
exp_width: int
Width of exponent
"""

def __init__(
self,
*,
sig_width: int = 24,
exp_width: int = 8,
):
self.sig_width = sig_width
self.exp_width = exp_width
176 changes: 176 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_error_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from amaranth import *
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import (
RoundingModes,
FPUParams,
Errors,
)


class FPUErrorMethodLayout:
"""FPU error checking module layouts for methods
Parameters
----------
fpu_params: FPUParams
FPU parameters
"""

def __init__(self, *, fpu_params: FPUParams):
"""
input_inf is a flag that comes from previous stage.
Its purpose is to indicate that the infinity on input
is a result of infinity arithmetic and not a result of overflow
"""
self.error_in_layout = [
("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("rounding_mode", RoundingModes),
("inexact", 1),
("invalid_operation", 1),
("division_by_zero", 1),
("input_inf", 1),
]
self.error_out_layout = [
("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("errors", Errors),
]


class FPUErrorModule(Elaboratable):
"""FPU error checking module
Parameters
----------
fpu_params: FPUParams
FPU rounding module parameters
Attributes
----------
error_checking_request: Method
Transactional method for initiating error checking of a floating point number.
Takes 'error_in_layout' as argument
Returns final number and errors as 'error_out_layout'
"""

def __init__(self, *, fpu_params: FPUParams):

self.fpu_errors_params = fpu_params
self.method_layouts = FPUErrorMethodLayout(fpu_params=self.fpu_errors_params)
self.error_checking_request = Method(
i=self.method_layouts.error_in_layout,
o=self.method_layouts.error_out_layout,
)

def elaborate(self, platform):
m = TModule()

max_exp = C(
2 ** (self.fpu_errors_params.exp_width) - 1,
unsigned(self.fpu_errors_params.exp_width),
)
max_normal_exp = C(
2 ** (self.fpu_errors_params.exp_width) - 2,
unsigned(self.fpu_errors_params.exp_width),
)
max_sig = C(
2 ** (self.fpu_errors_params.sig_width) - 1,
unsigned(self.fpu_errors_params.sig_width),
)

overflow = Signal()
underflow = Signal()
inexact = Signal()
tininess = Signal()

final_exp = Signal(self.fpu_errors_params.exp_width)
final_sig = Signal(self.fpu_errors_params.sig_width)
final_sign = Signal()
final_errors = Signal(5)

@def_method(m, self.error_checking_request)
def _(arg):
is_nan = arg.invalid_operation | ((arg.exp == max_exp) & (arg.sig.any()))
is_inf = arg.division_by_zero | arg.input_inf
input_not_special = ~(is_nan) & ~(is_inf)
m.d.av_comb += overflow.eq(input_not_special & (arg.exp == max_exp))
m.d.av_comb += tininess.eq((arg.exp == 0) & (~arg.sig[-1]))
m.d.av_comb += inexact.eq(overflow | (input_not_special & arg.inexact))
m.d.av_comb += underflow.eq(tininess & inexact)

with m.If(is_nan | is_inf):

m.d.av_comb += final_exp.eq(arg.exp)
m.d.av_comb += final_sig.eq(arg.sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Elif(overflow):

with m.Switch(arg.rounding_mode):
with m.Case(RoundingModes.ROUND_NEAREST_AWAY, RoundingModes.ROUND_NEAREST_EVEN):

m.d.av_comb += final_exp.eq(max_exp)
m.d.av_comb += final_sig.eq(0)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Case(RoundingModes.ROUND_ZERO):

m.d.av_comb += final_exp.eq(max_normal_exp)
m.d.av_comb += final_sig.eq(max_sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Case(RoundingModes.ROUND_DOWN):

with m.If(arg.sign):

m.d.av_comb += final_exp.eq(max_exp)
m.d.av_comb += final_sig.eq(0)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Else():

m.d.av_comb += final_exp.eq(max_normal_exp)
m.d.av_comb += final_sig.eq(max_sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Case(RoundingModes.ROUND_UP):

with m.If(arg.sign):

m.d.av_comb += final_exp.eq(max_normal_exp)
m.d.av_comb += final_sig.eq(max_sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Else():

m.d.av_comb += final_exp.eq(max_exp)
m.d.av_comb += final_sig.eq(0)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Else():
with m.If((arg.exp == 0) & (arg.sig[-1] == 1)):
m.d.av_comb += final_exp.eq(1)
with m.Else():
m.d.av_comb += final_exp.eq(arg.exp)
m.d.av_comb += final_sig.eq(arg.sig)
m.d.av_comb += final_sign.eq(arg.sign)

m.d.av_comb += final_errors.eq(
Mux(arg.invalid_operation, Errors.INVALID_OPERATION, 0)
| Mux(arg.division_by_zero, Errors.DIVISION_BY_ZERO, 0)
| Mux(overflow, Errors.OVERFLOW, 0)
| Mux(underflow, Errors.UNDERFLOW, 0)
| Mux(inexact, Errors.INEXACT, 0)
)

return {
"exp": final_exp,
"sig": final_sig,
"sign": final_sign,
"errors": final_errors,
}

return m
117 changes: 117 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_rounding_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from amaranth import *
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import (
RoundingModes,
FPUParams,
)


class FPURoudningMethodLayout:
"""FPU Rounding module layouts for methods
Parameters
----------
fpu_params: FPUParams
FPU parameters
"""

def __init__(self, *, fpu_params: FPUParams):
self.rounding_in_layout = [
("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("round_bit", 1),
("sticky_bit", 1),
("rounding_mode", RoundingModes),
]
self.rounding_out_layout = [
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("inexact", 1),
]


class FPURounding(Elaboratable):
"""FPU Rounding module
Parameters
----------
fpu_params: FPUParams
FPU parameters
Attributes
----------
rounding_request: Method
Transactional method for initiating rounding of a floating point number.
Takes 'rounding_in_layout' as argument
Returns rounded number and errors as 'rounding_out_layout'
"""

def __init__(self, *, fpu_params: FPUParams):

self.fpu_rounding_params = fpu_params
self.method_layouts = FPURoudningMethodLayout(fpu_params=self.fpu_rounding_params)
self.rounding_request = Method(
i=self.method_layouts.rounding_in_layout,
o=self.method_layouts.rounding_out_layout,
)

def elaborate(self, platform):
m = TModule()

add_one = Signal()
inc_rtnte = Signal()
inc_rtnta = Signal()
inc_rtpi = Signal()
inc_rtmi = Signal()

rounded_sig = Signal(self.fpu_rounding_params.sig_width + 1)
normalised_sig = Signal(self.fpu_rounding_params.sig_width)
rounded_exp = Signal(self.fpu_rounding_params.exp_width)

final_round_bit = Signal()
final_sticky_bit = Signal()

inexact = Signal()

@def_method(m, self.rounding_request)
def _(arg):

m.d.av_comb += inc_rtnte.eq(
(arg.rounding_mode == RoundingModes.ROUND_NEAREST_EVEN)
& (arg.round_bit & (arg.sticky_bit | arg.sig[0]))
)
m.d.av_comb += inc_rtnta.eq((arg.rounding_mode == RoundingModes.ROUND_NEAREST_AWAY) & (arg.round_bit))
m.d.av_comb += inc_rtpi.eq(
(arg.rounding_mode == RoundingModes.ROUND_UP) & (~arg.sign & (arg.round_bit | arg.sticky_bit))
)
m.d.av_comb += inc_rtmi.eq(
(arg.rounding_mode == RoundingModes.ROUND_DOWN) & (arg.sign & (arg.round_bit | arg.sticky_bit))
)

m.d.av_comb += add_one.eq(inc_rtmi | inc_rtnta | inc_rtnte | inc_rtpi)

m.d.av_comb += rounded_sig.eq(arg.sig + add_one)

with m.If(rounded_sig[-1]):

m.d.av_comb += normalised_sig.eq(rounded_sig >> 1)
m.d.av_comb += final_round_bit.eq(rounded_sig[0])
m.d.av_comb += final_sticky_bit.eq(arg.round_bit | arg.sticky_bit)
m.d.av_comb += rounded_exp.eq(arg.exp + 1)

with m.Else():
m.d.av_comb += normalised_sig.eq(rounded_sig)
m.d.av_comb += final_round_bit.eq(arg.round_bit)
m.d.av_comb += final_sticky_bit.eq(arg.sticky_bit)
m.d.av_comb += rounded_exp.eq(arg.exp)

m.d.av_comb += inexact.eq(final_round_bit | final_sticky_bit)

return {
"exp": rounded_exp,
"sig": normalised_sig,
"inexact": inexact,
}

return m
Loading

0 comments on commit 34565a3

Please sign in to comment.