Skip to content

Commit

Permalink
[RTL MVU] Update code generation to take dsp variant into account
Browse files Browse the repository at this point in the history
  • Loading branch information
auphelia committed May 27, 2024
1 parent 739d644 commit dbf8ed7
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 89 deletions.
3 changes: 2 additions & 1 deletion finn-rtllib/mvu/mvu_vvu_axi_wrapper.v
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #(
parameter ACTIVATION_WIDTH = $ACTIVATION_WIDTH$,
parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$,
parameter ACCU_WIDTH = $ACCU_WIDTH$,
parameter NARROW_WEIGHTS = $NARROW_WEIGHTS$,
parameter SIGNED_ACTIVATIONS = $SIGNED_ACTIVATIONS$,
parameter SEGMENTLEN = $SEGMENTLEN$,
parameter FORCE_BEHAVIORAL = $FORCE_BEHAVIORAL$,
Expand Down Expand Up @@ -77,7 +78,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #(

mvu_vvu_axi #(
.IS_MVU(IS_MVU), .COMPUTE_CORE(COMPUTE_CORE), .PUMPED_COMPUTE(PUMPED_COMPUTE), .MW(MW), .MH(MH), .PE(PE), .SIMD(SIMD),
.ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH),
.ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), .NARROW_WEIGHTS(NARROW_WEIGHTS),
.SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .SEGMENTLEN(SEGMENTLEN), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)
) inst (
.ap_clk(ap_clk),
Expand Down
38 changes: 21 additions & 17 deletions src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from finn.custom_op.fpgadataflow.matrixvectoractivation import MVAU
from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend
from finn.util.basic import get_rtlsim_trace_depth, make_build_dir
from finn.util.basic import get_dsp_block, get_rtlsim_trace_depth, make_build_dir
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy

try:
Expand All @@ -55,10 +55,7 @@ def __init__(self, onnx_node, **kwargs):
super().__init__(onnx_node, **kwargs)

def get_nodeattr_types(self):
my_attrs = {
# Flag to indicate if Versal device is targeted
"is_versal": ("i", False, 0, {0, 1}),
}
my_attrs = {}
my_attrs.update(MVAU.get_nodeattr_types(self))
my_attrs.update(RTLBackend.get_nodeattr_types(self))
return my_attrs
Expand Down Expand Up @@ -141,10 +138,11 @@ def dsp_estimation(self):
# multiplication
P = self.get_nodeattr("PE")
Q = self.get_nodeattr("SIMD")
if self.get_nodeattr("is_versal"):
mult_dsp = P * np.ceil(Q / 3)
else:
mult_dsp = np.ceil(P / 4) * Q
# TODO: get dsp block type
# if dsp_block = "DSP58":
# mult_dsp = P * np.ceil(Q / 3)
# else:
mult_dsp = np.ceil(P / 4) * Q
return int(mult_dsp)

def instantiate_ip(self, cmd):
Expand Down Expand Up @@ -186,7 +184,7 @@ def _resolve_segment_len(self, clk):
dsp_chain_len = critical_path_dsps if critical_path_dsps < max_chain_len else max_chain_len
return dsp_chain_len

def _resolve_impl_style(self, fpgapart):
def _resolve_impl_style(self, dsp_block):
# Based on target device and activation/weight-width, choose the
# supported RTL compute core
assert (
Expand All @@ -198,15 +196,15 @@ def _resolve_impl_style(self, fpgapart):

act_width = self.get_input_datatype(0).bitwidth()
weight_width = self.get_input_datatype(1).bitwidth()
is_versal_family = self.get_nodeattr("is_versal")

if is_versal_family:
if dsp_block == "DSP58":
return "mvu_vvu_8sx9_dsp58"
else:
act_width = self.get_input_datatype(0).bitwidth()
weight_width = self.get_input_datatype(1).bitwidth()
if (act_width == 4 and weight_width == 4) and not (is_versal_family):
return "mvu_4sx4u"
if act_width <= 4 and weight_width <= 4:
if dsp_block == "DSP48E1":
return "mvu_4sx4u_dsp48e1"
elif dsp_block == "DSP48E2":
return "mvu_4sx4u_dsp48e2"
else:
return "mvu_8sx8u_dsp48"

Expand All @@ -216,6 +214,11 @@ def generate_hdl(self, model, fpgapart, clk):
self.generate_params(model, code_gen_dir)

template_path, code_gen_dict = self.prepare_codegen_default(fpgapart, clk)
# determine if weights are narrow range and add parameter to code gen dict
weights = model.get_initializer(self.onnx_node.input[1])
wdt = self.get_weight_datatype()
narrow_weights = 0 if np.min(weights) == wdt.min() else 1
code_gen_dict["$NARROW_WEIGHTS$"] = str(narrow_weights)
# add general parameters to dictionary
code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"] = [self.get_verilog_top_module_name()]
# save top module name so we can refer to it after this node has been renamed
Expand Down Expand Up @@ -248,9 +251,10 @@ def generate_hdl(self, model, fpgapart, clk):
def prepare_codegen_default(self, fpgapart, clk):
template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mvu/mvu_vvu_axi_wrapper.v"

dsp_block = get_dsp_block(fpgapart)
code_gen_dict = {}
code_gen_dict["$IS_MVU$"] = [str(1)]
code_gen_dict["$COMPUTE_CORE$"] = [self._resolve_impl_style(fpgapart)]
code_gen_dict["$COMPUTE_CORE$"] = [self._resolve_impl_style(dsp_block)]
code_gen_dict["$MW$"] = [str(self.get_nodeattr("MW"))]
code_gen_dict["$MH$"] = [str(self.get_nodeattr("MH"))]
code_gen_dict["$PE$"] = [str(self.get_nodeattr("PE"))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@

from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend
from finn.custom_op.fpgadataflow.vectorvectoractivation import VVAU
from finn.util.basic import get_rtlsim_trace_depth, make_build_dir
from finn.util.basic import get_rtlsim_trace_depth, is_versal, make_build_dir
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
from finn.util.fpgadataflow import is_versal

try:
from pyverilator import PyVerilator
Expand Down
122 changes: 61 additions & 61 deletions src/finn/transformation/fpgadataflow/specialize_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,27 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from onnx import helper
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation

from finn.custom_op.fpgadataflow.hls import custom_op as hls_variants
from finn.custom_op.fpgadataflow.rtl import custom_op as rtl_variants
from finn.util.fpgadataflow import is_versal
from finn.util.basic import get_dsp_block, is_versal


def _determine_impl_style(node, fpgapart):
def _determine_impl_style(node, fpgapart, model):
optype = node.op_type

# check if there is an HLS or RTL variant or both
hls_variant = optype + "_hls" in hls_variants.keys()
rtl_variant = optype + "_rtl" in rtl_variants.keys()

# check if user has specified a preferred_impl_style
inst = getCustomOp(node)
impl_style = inst.get_nodeattr("preferred_impl_style")
node_inst = getCustomOp(node)
impl_style = node_inst.get_nodeattr("preferred_impl_style")

# if impl_style not set, for "simple" layers always try
# to use rtl variant if available
Expand All @@ -55,23 +55,19 @@ def _determine_impl_style(node, fpgapart):
return _dwc_determine_impl_style(node)
if rtl_variant:
if optype == "MVAU":
inp_width_fit = (
DataType[getCustomOp(node).get_nodeattr("inputDataType")].bitwidth() >= 4
)
weight_width_fit = (
DataType[getCustomOp(node).get_nodeattr("weightDataType")].bitwidth() >= 4
)
if inp_width_fit and weight_width_fit and _mvu_rtl_possible(node):
idt = node_inst.get_input_datatype()
wdt = node_inst.get_weight_datatype()
inp_width_fit = idt.bitwidth() >= 4
weight_width_fit = wdt.bitwidth() >= 4
if inp_width_fit and weight_width_fit and _mvu_rtl_possible(node, fpgapart, model):
return "rtl"
else:
return "hls"
elif optype == "VVAU":
inp_width_fit = (
DataType[getCustomOp(node).get_nodeattr("inputDataType")].bitwidth() >= 4
)
weight_width_fit = (
DataType[getCustomOp(node).get_nodeattr("weightDataType")].bitwidth() >= 4
)
idt = node_inst.get_input_datatype()
wdt = node_inst.get_weight_datatype()
inp_width_fit = idt.bitwidth() >= 4
weight_width_fit = wdt.bitwidth() >= 4
if inp_width_fit and weight_width_fit and _vvu_rtl_possible(node, fpgapart):
return "rtl"
else:
Expand Down Expand Up @@ -136,7 +132,7 @@ def _determine_impl_style(node, fpgapart):
# user setting can be fulfilled
return "rtl"
elif optype == "MVAU":
if _mvu_rtl_possible(node):
if _mvu_rtl_possible(node, fpgapart, model):
return "rtl"
else:
warn_str = """There is no RTL variant for %s. The node will automatically be
Expand Down Expand Up @@ -232,56 +228,63 @@ def _swg_hls_possible(node):
return False


def _mvu_rtl_possible(n):
def _mvu_rtl_possible(n, fpgapart, model):
# Checks whether RTL-based MVU is supported
# Currently, for DSP48 we only support computations up to
# 8sx8u (8-bit signed weights x 8-bit (un)signed activations)
# and for DSP58 we support up to 8sx9s. Next to that,
# embedded thresholding functionality is not supported and
# neither binaryxnormode computation.
inp_width_in_range = (
DataType[getCustomOp(n).get_nodeattr("inputDataType")].bitwidth() <= 8
) or (
DataType[getCustomOp(n).get_nodeattr("inputDataType")].bitwidth() == 9
and DataType[getCustomOp(n).get_nodeattr("inputDataType")].min() < 0
)
weight_width_in_range = DataType[getCustomOp(n).get_nodeattr("weightDataType")].bitwidth() <= 8
signed_weights = DataType[getCustomOp(n).get_nodeattr("weightDataType")].min() < 0
no_activation = getCustomOp(n).get_nodeattr("noActivation") == 1
not_binaryxnor_mode = getCustomOp(n).get_nodeattr("binaryXnorMode") == 0
# and for DSP58 we support up to 8sx9s.
# Please note, DSP48E1 does only support narrow range for weights
# Next to that, embedded thresholding functionality is not supported
# and neither binaryxnormode computation.
node_inst = getCustomOp(n)
# first check if no Activation or binary xnor mode and return False
# immediately if one of them is True
no_activation = node_inst.get_nodeattr("noActivation") == 0
not_binaryxnor_mode = node_inst.get_nodeattr("binaryXnorMode") == 1
if no_activation or not_binaryxnor_mode:
return False

return (
inp_width_in_range
and weight_width_in_range
and signed_weights
and no_activation
and not_binaryxnor_mode
)
# check if weights are signed, if not return False
wdt = node_inst.get_weight_datatype()
if not wdt.signed():
return False

# check which dsp block is available on fpga
dsp_block = get_dsp_block(fpgapart)
# check if weights are narrow
weights = model.get_initializer(n.input[1])
narrow_weights = False if np.min(weights) == wdt.min() else True
# if non narrow weights and only DSP48E1 available return False
if not narrow_weights and dsp_block == "DSP48E1":
return False

# if none of the above constraints have been triggered
# we now check if input and weight data types are in range
idt = node_inst.get_input_datatype()
inp_width_in_range = (idt.bitwidth() <= 8) or (idt.bitwidth() == 9 and idt.signed())
weight_width_in_range = wdt.bitwidth() <= 8

return inp_width_in_range and weight_width_in_range


def _vvu_rtl_possible(n, fpgapart):
# Checks whether RTL-based VVU is supported
# Currently, we only support RTL-VVU on DSP58 up to 8sx9s inputs
# (8-bit signed weights x (9-bit signed OR 8-bit (un)signed) activations).
# Next to that, embedded thresholding functionality is not supported.
in_width_in_range = (
DataType[getCustomOp(n).get_nodeattr("inputDataType")].bitwidth() <= 8
) or (
DataType[getCustomOp(n).get_nodeattr("inputDataType")].bitwidth() == 9
and DataType[getCustomOp(n).get_nodeattr("inputDataType")].min() < 0
)
weight_width_in_range = DataType[getCustomOp(n).get_nodeattr("weightDataType")].bitwidth() <= 8
signed_weights = DataType[getCustomOp(n).get_nodeattr("weightDataType")].min() < 0
is_versal_family = is_versal(fpgapart)
no_activation = getCustomOp(n).get_nodeattr("noActivation") == 1
node_inst = getCustomOp(n)
if not node_inst.get_nodeattr("noActivation"):
return False
if not is_versal(fpgapart):
return False

idt = node_inst.get_input_datatype()
wdt = node_inst.get_weight_datatype()
in_width_in_range = (idt.bitwidth() <= 8) or (idt.bitwidth() == 9 and idt.min() < 0)
weight_width_in_range = wdt.bitwidth() <= 8
signed_weights = wdt.min() < 0

return (
in_width_in_range
and weight_width_in_range
and signed_weights
and is_versal_family
and no_activation
)
return in_width_in_range and weight_width_in_range and signed_weights


class SpecializeLayers(Transformation):
Expand All @@ -300,7 +303,7 @@ def apply(self, model):
if not node.domain == "finn.custom_op.fpgadataflow":
continue
node_ind += 1
impl_style = _determine_impl_style(node, self.fpgapart)
impl_style = _determine_impl_style(node, self.fpgapart, model)
optype = node.op_type + "_" + impl_style

new_node = helper.make_node(
Expand All @@ -313,9 +316,6 @@ def apply(self, model):
for attribute in node.attribute:
if attribute.name != "preferred_impl_style":
new_node.attribute.append(attribute)
if new_node.op_type == "MVAU_rtl":
is_versal_family = is_versal(self.fpgapart)
getCustomOp(new_node).set_nodeattr("is_versal", is_versal_family)
graph.node.insert(node_ind, new_node)
# remove old nodes
graph.node.remove(node)
Expand Down
17 changes: 17 additions & 0 deletions src/finn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,20 @@ def memutil(req_mem_spec, primitive_spec):
eff = (req_width * req_depth) / (count * prim_width * prim_depth)
waste = (count * prim_width * prim_depth) - (req_width * req_depth)
return (count, eff, waste)


def is_versal(fpgapart):
"""Returns whether board is part of the Versal family"""
return (
fpgapart[0:4] in ["xcvc", "xcve", "xcvp", "xcvm", "xqvc", "xqvm"]
or fpgapart[0:5] == "xqrvc"
)


def get_dsp_block(fpgapart):
if is_versal(fpgapart):
return "DSP58"
elif fpgapart[2] == "7":
return "DSP48E1"
else:
return "DSP48E2"
8 changes: 0 additions & 8 deletions src/finn/util/fpgadataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,3 @@ def is_rtl_node(node):
is_node = True

return is_node


def is_versal(fpgapart):
"""Returns whether board is part of the Versal family"""
return (
fpgapart[0:4] in ["xcvc", "xcve", "xcvp", "xcvm", "xqvc", "xqvm"]
or fpgapart[0:5] == "xqrvc"
)

0 comments on commit dbf8ed7

Please sign in to comment.