Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FPU rounding module #728

Merged
merged 18 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
30 changes: 30 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from amaranth.lib import enum


class RoundingModes(enum.Enum):
ROUND_UP = 3
ROUND_DOWN = 2
ROUND_ZERO = 1
ROUND_NEAREST_EVEN = 0
ROUND_NEAREST_AWAY = 4


class FPUParams:
"""FPU parameters
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing here with documentation: it would be helpful to add information that sig_width contains the implicit bit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


Parameters
----------
sig_width: int
Width of significand
exp_width: int
Width of exponent
"""

def __init__(
self,
*,
sig_width: int = 24,
exp_width: int = 8,
):
self.sig_width = sig_width
self.exp_width = exp_width
168 changes: 168 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_error_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from amaranth import *
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import (
RoundingModes,
FPUParams,
)


class FPUErrorMethodLayout:
"""FPU error checking module layouts for methods

Parameters
----------
fpu_params: FPUParams
FPU parameters
"""

def __init__(self, *, fpu_params: FPUParams):
self.error_in_layout = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usage of input_inf as in input to previous stages is not obvious by name, comment about that here could be helpful for future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please resolve comment from previous review

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late response. I pushed changes at around 3 a.m. and went to sleep with a plan to resolve comments next day, but by then you had already reviewed them. I was also thinking if maybe a better solution would be to write about this flag in some class for input layout for modules that perform arithmetic operations or something like that, but probably a small comment about why an error module needs it is warranted, so I added it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was finished, sorry!

("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("rounding_mode", RoundingModes),
("inexact", 1),
("invalid_operation", 1),
("division_by_zero", 1),
("input_inf", 1),
]
self.error_out_layout = [
("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("errors", 5),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use IntFlag enum as errors shape and when referencing them

]


class FPUErrorModule(Elaboratable):
"""FPU error checking module

Parameters
----------
fpu_params: FPUParams
FPU rounding module parameters

Attributes
----------
error_checking_request: Method
Transactional method for initiating error checking of a floating point number.
Takes 'error_in_layout' as argument
Returns final number and errors as 'error_out_layout'
"""

def __init__(self, *, fpu_params: FPUParams):

self.fpu_errors_params = fpu_params
self.method_layouts = FPUErrorMethodLayout(fpu_params=self.fpu_errors_params)
self.error_checking_request = Method(
i=self.method_layouts.error_in_layout,
o=self.method_layouts.error_out_layout,
)

def elaborate(self, platform):
m = TModule()

max_exp = C(
2 ** (self.fpu_errors_params.exp_width) - 1,
unsigned(self.fpu_errors_params.exp_width),
)
max_normal_exp = C(
2 ** (self.fpu_errors_params.exp_width) - 2,
unsigned(self.fpu_errors_params.exp_width),
)
max_sig = C(
2 ** (self.fpu_errors_params.sig_width) - 1,
unsigned(self.fpu_errors_params.sig_width),
)

overflow = Signal()
underflow = Signal()
inexact = Signal()
tininess = Signal()

final_exp = Signal(self.fpu_errors_params.exp_width)
final_sig = Signal(self.fpu_errors_params.sig_width)
final_sign = Signal()
final_errors = Signal(5)

@def_method(m, self.error_checking_request)
def _(arg):
is_nan = arg.invalid_operation | ((arg.exp == max_exp) & (arg.sig.any()))
is_inf = arg.division_by_zero | arg.input_inf
input_not_special = ~(is_nan) & ~(is_inf)
m.d.av_comb += overflow.eq(input_not_special & (arg.exp == max_exp))
m.d.av_comb += tininess.eq((arg.exp == 0) & (~arg.sig[-1]))
m.d.av_comb += inexact.eq(overflow | (input_not_special & arg.inexact))
m.d.av_comb += underflow.eq(tininess & inexact)

with m.If(is_nan | is_inf):

m.d.av_comb += final_exp.eq(arg.exp)
m.d.av_comb += final_sig.eq(arg.sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Elif(overflow):

with m.Switch(arg.rounding_mode):
with m.Case(RoundingModes.ROUND_NEAREST_AWAY, RoundingModes.ROUND_NEAREST_EVEN):

m.d.av_comb += final_exp.eq(max_exp)
m.d.av_comb += final_sig.eq(0)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Case(RoundingModes.ROUND_ZERO):

m.d.av_comb += final_exp.eq(max_normal_exp)
m.d.av_comb += final_sig.eq(max_sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Case(RoundingModes.ROUND_DOWN):

with m.If(arg.sign):

m.d.av_comb += final_exp.eq(max_exp)
m.d.av_comb += final_sig.eq(0)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Else():

m.d.av_comb += final_exp.eq(max_normal_exp)
m.d.av_comb += final_sig.eq(max_sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Case(RoundingModes.ROUND_UP):

with m.If(arg.sign):

m.d.av_comb += final_exp.eq(max_normal_exp)
m.d.av_comb += final_sig.eq(max_sig)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Else():

m.d.av_comb += final_exp.eq(max_exp)
m.d.av_comb += final_sig.eq(0)
m.d.av_comb += final_sign.eq(arg.sign)

with m.Else():
with m.If((arg.exp == 0) & (arg.sig[-1] == 1)):
m.d.av_comb += final_exp.eq(1)
with m.Else():
m.d.av_comb += final_exp.eq(arg.exp)
m.d.av_comb += final_sig.eq(arg.sig)
m.d.av_comb += final_sign.eq(arg.sign)

m.d.av_comb += final_errors[0].eq(arg.invalid_operation)
m.d.av_comb += final_errors[1].eq(arg.division_by_zero)
m.d.av_comb += final_errors[2].eq(overflow)
m.d.av_comb += final_errors[3].eq(underflow)
m.d.av_comb += final_errors[4].eq(inexact)

return {
"exp": final_exp,
"sig": final_sig,
"sign": final_sign,
"errors": final_errors,
}

return m
117 changes: 117 additions & 0 deletions coreblocks/func_blocks/fu/fpu/fpu_rounding_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from amaranth import *
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import (
RoundingModes,
FPUParams,
)


class FPURoudningMethodLayout:
"""FPU Rounding module layouts for methods

Parameters
----------
fpu_params: FPUParams
FPU parameters
"""

def __init__(self, *, fpu_params: FPUParams):
self.rounding_in_layout = [
("sign", 1),
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("round_bit", 1),
("sticky_bit", 1),
("rounding_mode", RoundingModes),
]
self.rounding_out_layout = [
("sig", fpu_params.sig_width),
("exp", fpu_params.exp_width),
("inexact", 1),
]


class FPUrounding(Elaboratable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class FPUrounding(Elaboratable):
class FPURounding(Elaboratable):

typo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, my bad. Typo is now fixed.

"""FPU Rounding module

Parameters
----------
fpu_params: FPUParams
FPU parameters

Attributes
----------
rounding_request: Method
Transactional method for initiating rounding of a floating point number.
Takes 'rounding_in_layout' as argument
Returns rounded number and errors as 'rounding_out_layout'
"""

def __init__(self, *, fpu_params: FPUParams):

self.fpu_rounding_params = fpu_params
self.method_layouts = FPURoudningMethodLayout(fpu_params=self.fpu_rounding_params)
self.rounding_request = Method(
i=self.method_layouts.rounding_in_layout,
o=self.method_layouts.rounding_out_layout,
)

def elaborate(self, platform):
m = TModule()

add_one = Signal()
inc_rtnte = Signal()
inc_rtnta = Signal()
inc_rtpi = Signal()
inc_rtmi = Signal()

rounded_sig = Signal(self.fpu_rounding_params.sig_width + 1)
normalised_sig = Signal(self.fpu_rounding_params.sig_width)
rounded_exp = Signal(self.fpu_rounding_params.exp_width)

final_round_bit = Signal()
final_sticky_bit = Signal()

inexact = Signal()

@def_method(m, self.rounding_request)
piotro888 marked this conversation as resolved.
Show resolved Hide resolved
def _(arg):

m.d.av_comb += inc_rtnte.eq(
(arg.rounding_mode == RoundingModes.ROUND_NEAREST_EVEN)
& (arg.round_bit & (arg.sticky_bit | arg.sig[0]))
)
m.d.av_comb += inc_rtnta.eq((arg.rounding_mode == RoundingModes.ROUND_NEAREST_AWAY) & (arg.round_bit))
m.d.av_comb += inc_rtpi.eq(
(arg.rounding_mode == RoundingModes.ROUND_UP) & (~arg.sign & (arg.round_bit | arg.sticky_bit))
)
m.d.av_comb += inc_rtmi.eq(
(arg.rounding_mode == RoundingModes.ROUND_DOWN) & (arg.sign & (arg.round_bit | arg.sticky_bit))
)

m.d.av_comb += add_one.eq(inc_rtmi | inc_rtnta | inc_rtnte | inc_rtpi)

m.d.av_comb += rounded_sig.eq(arg.sig + add_one)

with m.If(rounded_sig[-1]):

m.d.av_comb += normalised_sig.eq(rounded_sig >> 1)
m.d.av_comb += final_round_bit.eq(rounded_sig[0])
m.d.av_comb += final_sticky_bit.eq(arg.round_bit | arg.sticky_bit)
m.d.av_comb += rounded_exp.eq(arg.exp + 1)

with m.Else():
m.d.av_comb += normalised_sig.eq(rounded_sig)
m.d.av_comb += final_round_bit.eq(arg.round_bit)
m.d.av_comb += final_sticky_bit.eq(arg.sticky_bit)
m.d.av_comb += rounded_exp.eq(arg.exp)

m.d.av_comb += inexact.eq(final_round_bit | final_sticky_bit)

return {
"exp": rounded_exp,
"sig": normalised_sig,
"inexact": inexact,
}

return m
Loading