From 1dd118b146b3310daea3835c67cfa7c102631992 Mon Sep 17 00:00:00 2001 From: auphelia Date: Wed, 29 May 2024 13:53:25 +0100 Subject: [PATCH] [RTL MVAU] Bring back is_versal node attribute for resource estimations --- .../fpgadataflow/rtl/matrixvectoractivation_rtl.py | 14 ++++++++------ .../fpgadataflow/specialize_layers.py | 3 +++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index a6a8e72bdf..d307efe988 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -55,7 +55,10 @@ def __init__(self, onnx_node, **kwargs): super().__init__(onnx_node, **kwargs) def get_nodeattr_types(self): - my_attrs = {} + my_attrs = { + # Flag to indicate if Versal device is targeted + "is_versal": ("i", False, 0, {0, 1}), + } my_attrs.update(MVAU.get_nodeattr_types(self)) my_attrs.update(RTLBackend.get_nodeattr_types(self)) return my_attrs @@ -138,11 +141,10 @@ def dsp_estimation(self): # multiplication P = self.get_nodeattr("PE") Q = self.get_nodeattr("SIMD") - # TODO: get dsp block type - # if dsp_block = "DSP58": - # mult_dsp = P * np.ceil(Q / 3) - # else: - mult_dsp = np.ceil(P / 4) * Q + if self.get_nodeattr("is_versal"): + mult_dsp = P * np.ceil(Q / 3) + else: + mult_dsp = np.ceil(P / 4) * Q return int(mult_dsp) def instantiate_ip(self, cmd): diff --git a/src/finn/transformation/fpgadataflow/specialize_layers.py b/src/finn/transformation/fpgadataflow/specialize_layers.py index dbcadd1df5..9a88d34787 100644 --- a/src/finn/transformation/fpgadataflow/specialize_layers.py +++ b/src/finn/transformation/fpgadataflow/specialize_layers.py @@ -316,6 +316,9 @@ def apply(self, model): for attribute in node.attribute: if attribute.name != "preferred_impl_style": new_node.attribute.append(attribute) + if new_node.op_type == "MVAU_rtl": + is_versal_family = is_versal(self.fpgapart) + getCustomOp(new_node).set_nodeattr("is_versal", is_versal_family) graph.node.insert(node_ind, new_node) # remove old nodes graph.node.remove(node)