Skip to content

Commit

Permalink
Merge pull request Xilinx#1072 from Xilinx/bugfix/rtl_thresholding
Browse files Browse the repository at this point in the history
Bugfix RTL Thresholding
  • Loading branch information
auphelia authored May 13, 2024
2 parents 39fb885 + 9e32c81 commit 43fc12b
Show file tree
Hide file tree
Showing 5 changed files with 483 additions and 572 deletions.
67 changes: 29 additions & 38 deletions src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,23 @@ def prepare_codegen_rtl_values(self, model):
o_bitwidth = DataType[output_data_type].bitwidth()

# The RTL expects 2^N-1 thresholds, but narrow range quantization will result in
# one less threshold, prepending a dummy threshold and reducing bias by 1 to compensate.
# one less threshold, prepending a dummy threshold (the minimal possible value, determined
# by the input data type) and decreasing the bias by 1 to compensate.
# Additionally, increase the number of threshold steps to reflect the new shape.
expected_thresholds = 2**o_bitwidth - 1
n_thres_steps = self.get_nodeattr("numSteps")
if expected_thresholds != n_thres_steps and DataType[input_data_type].signed() is not True:
min_val = np.amin(thresholds, axis=1)
if expected_thresholds != n_thres_steps:
min_val = DataType[input_data_type].min()
thresholds = np.insert(thresholds, 0, min_val, axis=1)
bias = bias - 1
n_thres_steps += 1

# add dummy dimension as final dimension (that's what gets packed with next call)
thresholds = np.expand_dims(thresholds, axis=-1)
t_expand = np.expand_dims(thresholds, axis=-1)
wdt = self.get_weight_datatype()
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 4)
t_packed = pack_innermost_dim_as_hex_string(
thresholds,
t_expand,
wdt,
bw_hexdigit,
prefix="",
Expand All @@ -199,8 +202,8 @@ def prepare_codegen_rtl_values(self, model):
num_channels = self.get_nodeattr("NumChannels") # number of channels

# If the packed thresholds do not match the expected shape (e.g. a single
# threshold value was given), broadcast them to the expected shape
expected_shape = (num_channels, n_thres_steps)
if t_packed.shape == (1, 1):
expected_shape = (num_channels, expected_thresholds)
if t_packed.shape != expected_shape:
t_packed = np.broadcast_to(t_packed, expected_shape)

channel_fold = int(num_channels / pe)
Expand All @@ -224,6 +227,10 @@ def prepare_codegen_rtl_values(self, model):
f.write(val + "\n")
code_gen_dict["$THRESHOLDS_PATH$"] = ['"./%s_"' % self.onnx_node.name]

if self.get_nodeattr("runtime_writeable_weights") == 1:
thresh_file_name = f"{t_path}/memblock.dat"
self.make_weight_file(thresholds, "decoupled", thresh_file_name)

# Identify the module name
code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"] = [
self.get_verilog_top_module_name() + "_axi_wrapper"
Expand Down Expand Up @@ -255,7 +262,6 @@ def prepare_codegen_rtl_values(self, model):
o_bits = 1 + math.ceil(
math.log2(-bias if -bias >= 2 ** (o_bitwidth - 1) else 2**o_bitwidth + bias)
)

code_gen_dict["$O_BITS$"] = [str(int(o_bits))]

rt_weights = self.get_nodeattr("runtime_writeable_weights")
Expand Down Expand Up @@ -322,10 +328,6 @@ def generate_hdl(self, model, fpgapart, clk):
# by PyVerilator and IPI generation
self.set_nodeattr("gen_top_module", code_gen_dict["$TOP_MODULE$"][0])

weights = model.get_initializer(self.onnx_node.input[1])
weights_fname = f"{code_gen_dir}/memblock.dat"
self.make_weight_file(weights, "decoupled", weights_fname)

for rtl_file_path in self.get_rtl_file_paths():
# read in original RTL template file
template_data = self.get_rtl_template_data(rtl_file_path)
Expand Down Expand Up @@ -513,27 +515,16 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
* weight_file_name : filename for the weight file to be generated
"""
threshold_tensor = self.get_hw_compatible_threshold_tensor(weights)
tdt = self.get_weight_datatype()
assert np.vectorize(tdt.allowed)(
threshold_tensor
).all(), "Thresholds can't be expressed with type %s" % str(tdt)

thresholds = weights
pe = self.get_nodeattr("PE")
ch = self.get_nodeattr("NumChannels")
n_thres_steps = self.get_nodeattr("numSteps")

# If a single threshold value is found, broadcast the value
n_thres_steps = self.get_nodeattr("numSteps")
expected_shape = (ch, n_thres_steps)
if weights.shape == (1, 1):
weights = np.broadcast_to(weights, expected_shape)

odt = self.get_output_datatype().bitwidth()
width_padded = roundup_to_integer_multiple(weights.shape[1], 2**odt)
weight_padded = np.zeros((weights.shape[0], width_padded))
weight_padded[: weights.shape[0], :n_thres_steps] = weights
weight_stream = []
output_data_type = self.get_nodeattr("outputDataType") # output precision
o_bitwidth = DataType[output_data_type].bitwidth()
n_thres_steps = 2**o_bitwidth - 1
width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth)
thresh_padded = np.zeros((thresholds.shape[0], width_padded))
thresh_padded[: thresholds.shape[0], :n_thres_steps] = thresholds
thresh_stream = []
wdt = self.get_weight_datatype()
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 32)
padding = np.zeros(width_padded, dtype=np.int32)
Expand All @@ -543,18 +534,18 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
for fold in range(cf):
for c in range(2 ** (pe - 1).bit_length()):
if (c == 0 or c % pe != 0) and c < pe:
for w in weight_padded[chan_ind]:
w_packed = pack_innermost_dim_as_hex_string(
[w], wdt, bw_hexdigit, prefix=""
for t in thresh_padded[chan_ind]:
t_packed = pack_innermost_dim_as_hex_string(
[t], wdt, bw_hexdigit, prefix=""
).item()
weight_stream.append(w_packed)
thresh_stream.append(t_packed)
chan_ind += 1
else:
for z in padding:
w_packed = pack_innermost_dim_as_hex_string(
t_packed = pack_innermost_dim_as_hex_string(
[z], wdt, bw_hexdigit, prefix=""
).item()
weight_stream.append(w_packed)
thresh_stream.append(t_packed)
with open(weight_file_name, "w") as f:
for val in weight_stream:
for val in thresh_stream:
f.write(val + "\n")
6 changes: 2 additions & 4 deletions src/finn/custom_op/fpgadataflow/thresholding.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,23 +242,21 @@ def execute_node(self, context, graph):
node = self.onnx_node
inp_values = context[node.input[0]]
th_val = context[node.input[1]]
out_bias = self.get_nodeattr("ActVal")
# MT expects inputs to be in the shape (N,C,H,W) or (N, C)
# if 4D then input values in context are (N,H,W,C) and need to
# be transposed.
# if 2D then inputs can be passed directly to MT function
is_4d = len(inp_values.shape) == 4
if is_4d:
inp_values = np.transpose(inp_values, (0, 3, 1, 2))
y = multithreshold(inp_values, th_val)
y = multithreshold(inp_values, th_val, out_bias=out_bias)
if is_4d:
y = y.transpose(0, 2, 3, 1)
act = DataType[self.get_nodeattr("outputDataType")]
if act == DataType["BIPOLAR"]:
# binary to bipolar
y = 2 * y - 1
else:
# signed offset
y += act.min()
context[node.output[0]] = y

def calc_tmem(self):
Expand Down
205 changes: 0 additions & 205 deletions tests/fpgadataflow/test_convert_to_hw_thresholding.py

This file was deleted.

Loading

0 comments on commit 43fc12b

Please sign in to comment.