From 88258eaf4f7eb1751a2059b5f37e0c84c344c432 Mon Sep 17 00:00:00 2001 From: auphelia Date: Fri, 10 May 2024 15:55:52 +0100 Subject: [PATCH] [RTL Thresh] Move weight file generation for runtime writeable weights in separate function --- .../fpgadataflow/rtl/thresholding_rtl.py | 74 ++++++++++++------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py index 4541802e19..6970cde167 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py @@ -232,33 +232,7 @@ def prepare_codegen_rtl_values(self, model): if self.get_nodeattr("runtime_writeable_weights") == 1: thresh_file_name = f"{t_path}/memblock.dat" - width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth) - thresh_padded = np.zeros((thresholds.shape[0], width_padded)) - thresh_padded[: thresholds.shape[0], :n_thres_steps] = thresholds - thresh_stream = [] - bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 32) - padding = np.zeros(width_padded, dtype=np.int32) - - chan_ind = 0 - cf = ch // pe - for fold in range(cf): - for c in range(2 ** (pe - 1).bit_length()): - if (c == 0 or c % pe != 0) and c < pe: - for t in thresh_padded[chan_ind]: - t_packed = pack_innermost_dim_as_hex_string( - [t], wdt, bw_hexdigit, prefix="" - ).item() - thresh_stream.append(t_packed) - chan_ind += 1 - else: - for z in padding: - t_packed = pack_innermost_dim_as_hex_string( - [z], wdt, bw_hexdigit, prefix="" - ).item() - thresh_stream.append(t_packed) - with open(thresh_file_name, "w") as f: - for val in thresh_stream: - f.write(val + "\n") + self.make_weight_file(thresholds, "decoupled", thresh_file_name) # Identify the module name code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"] = [ @@ -532,3 +506,49 @@ def get_verilog_top_module_intf_names(self): intf_names["axilite"] = ["s_axilite"] return intf_names + + def make_weight_file(self, weights, weight_file_mode, weight_file_name): + """Produce a file containing given weights (thresholds) in appropriate + format for this layer. This file can be used for either synthesis or + run-time reconfig of weights. + + Arguments: + + * weights : numpy array with weights to be put into the file + * weight_file_name : filename for the weight file to be generated + + """ + thresholds = weights + pe = self.get_nodeattr("PE") + ch = self.get_nodeattr("NumChannels") + output_data_type = self.get_nodeattr("outputDataType") # output precision + o_bitwidth = DataType[output_data_type].bitwidth() + n_thres_steps = 2**o_bitwidth - 1 + width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth) + thresh_padded = np.zeros((thresholds.shape[0], width_padded)) + thresh_padded[: thresholds.shape[0], :n_thres_steps] = thresholds + thresh_stream = [] + wdt = self.get_weight_datatype() + bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 32) + padding = np.zeros(width_padded, dtype=np.int32) + + chan_ind = 0 + cf = ch // pe + for fold in range(cf): + for c in range(2 ** (pe - 1).bit_length()): + if (c == 0 or c % pe != 0) and c < pe: + for t in thresh_padded[chan_ind]: + t_packed = pack_innermost_dim_as_hex_string( + [t], wdt, bw_hexdigit, prefix="" + ).item() + thresh_stream.append(t_packed) + chan_ind += 1 + else: + for z in padding: + t_packed = pack_innermost_dim_as_hex_string( + [z], wdt, bw_hexdigit, prefix="" + ).item() + thresh_stream.append(t_packed) + with open(weight_file_name, "w") as f: + for val in thresh_stream: + f.write(val + "\n")