Skip to content

Commit

Permalink
[RTL Thresh] Move weight file generation for runtime writeable weight…
Browse files Browse the repository at this point in the history
…s in separate function
  • Loading branch information
auphelia committed May 10, 2024
1 parent 355bf99 commit 88258ea
Showing 1 changed file with 47 additions and 27 deletions.
74 changes: 47 additions & 27 deletions src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,33 +232,7 @@ def prepare_codegen_rtl_values(self, model):

if self.get_nodeattr("runtime_writeable_weights") == 1:
thresh_file_name = f"{t_path}/memblock.dat"
width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth)
thresh_padded = np.zeros((thresholds.shape[0], width_padded))
thresh_padded[: thresholds.shape[0], :n_thres_steps] = thresholds
thresh_stream = []
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 32)
padding = np.zeros(width_padded, dtype=np.int32)

chan_ind = 0
cf = ch // pe
for fold in range(cf):
for c in range(2 ** (pe - 1).bit_length()):
if (c == 0 or c % pe != 0) and c < pe:
for t in thresh_padded[chan_ind]:
t_packed = pack_innermost_dim_as_hex_string(
[t], wdt, bw_hexdigit, prefix=""
).item()
thresh_stream.append(t_packed)
chan_ind += 1
else:
for z in padding:
t_packed = pack_innermost_dim_as_hex_string(
[z], wdt, bw_hexdigit, prefix=""
).item()
thresh_stream.append(t_packed)
with open(thresh_file_name, "w") as f:
for val in thresh_stream:
f.write(val + "\n")
self.make_weight_file(thresholds, "decoupled", thresh_file_name)

# Identify the module name
code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"] = [
Expand Down Expand Up @@ -532,3 +506,49 @@ def get_verilog_top_module_intf_names(self):
intf_names["axilite"] = ["s_axilite"]

return intf_names

def make_weight_file(self, weights, weight_file_mode, weight_file_name):
"""Produce a file containing given weights (thresholds) in appropriate
format for this layer. This file can be used for either synthesis or
run-time reconfig of weights.
Arguments:
* weights : numpy array with weights to be put into the file
* weight_file_name : filename for the weight file to be generated
"""
thresholds = weights
pe = self.get_nodeattr("PE")
ch = self.get_nodeattr("NumChannels")
output_data_type = self.get_nodeattr("outputDataType") # output precision
o_bitwidth = DataType[output_data_type].bitwidth()
n_thres_steps = 2**o_bitwidth - 1
width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth)
thresh_padded = np.zeros((thresholds.shape[0], width_padded))
thresh_padded[: thresholds.shape[0], :n_thres_steps] = thresholds
thresh_stream = []
wdt = self.get_weight_datatype()
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 32)
padding = np.zeros(width_padded, dtype=np.int32)

chan_ind = 0
cf = ch // pe
for fold in range(cf):
for c in range(2 ** (pe - 1).bit_length()):
if (c == 0 or c % pe != 0) and c < pe:
for t in thresh_padded[chan_ind]:
t_packed = pack_innermost_dim_as_hex_string(
[t], wdt, bw_hexdigit, prefix=""
).item()
thresh_stream.append(t_packed)
chan_ind += 1
else:
for z in padding:
t_packed = pack_innermost_dim_as_hex_string(
[z], wdt, bw_hexdigit, prefix=""
).item()
thresh_stream.append(t_packed)
with open(weight_file_name, "w") as f:
for val in thresh_stream:
f.write(val + "\n")

0 comments on commit 88258ea

Please sign in to comment.