Skip to content

Commit

Permalink
Merge pull request Xilinx#1077 from Xilinx/feature/thresholding_types
Browse files Browse the repository at this point in the history
Independent input and threshold bit width for RTL Thresholding
  • Loading branch information
auphelia authored May 15, 2024
2 parents ae87807 + 460c70d commit 6191c42
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 27 deletions.
49 changes: 41 additions & 8 deletions finn-rtllib/thresholding/hdl/thresholding_axi.sv
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
*****************************************************************************/

module thresholding_axi #(
int unsigned N, // output precision
int unsigned K, // input/threshold precision
int unsigned N, // output precision
int unsigned WI, // input precision
int unsigned WT, // threshold precision
int unsigned C = 1, // Channels
int unsigned PE = 1, // Processing Parallelism, requires C = k*PE

Expand Down Expand Up @@ -96,7 +97,7 @@ module thresholding_axi #(
//- AXI Stream - Input --------------
output logic s_axis_tready,
input logic s_axis_tvalid,
input logic [((PE*K+7)/8)*8-1:0] s_axis_tdata,
input logic [((PE*WI+7)/8)*8-1:0] s_axis_tdata,

//- AXI Stream - Output -------------
input logic m_axis_tready,
Expand All @@ -109,13 +110,13 @@ module thresholding_axi #(
uwire cfg_en;
uwire cfg_we;
uwire [ADDR_BITS-3:0] cfg_a;
uwire [K -1:0] cfg_d;
uwire [WT -1:0] cfg_d;
uwire cfg_rack;
uwire [K -1:0] cfg_q;
uwire [WT -1:0] cfg_q;

if(USE_AXILITE) begin
uwire [ADDR_BITS-1:0] cfg_a0;
axi4lite_if #(.ADDR_WIDTH(ADDR_BITS), .DATA_WIDTH(32), .IP_DATA_WIDTH(K)) axi (
axi4lite_if #(.ADDR_WIDTH(ADDR_BITS), .DATA_WIDTH(32), .IP_DATA_WIDTH(WT)) axi (
.aclk(ap_clk), .aresetn(ap_rst_n),

.awready(s_axilite_AWREADY), .awvalid(s_axilite_AWVALID), .awaddr(s_axilite_AWADDR), .awprot('x),
Expand Down Expand Up @@ -143,10 +144,42 @@ module thresholding_axi #(
assign cfg_d = 'x;
end

//-----------------------------------------------------------------------
// Cast Inputs into Threshold Data Type
uwire [PE-1:0][WT-1:0] idat;
for(genvar pe = 0; pe < PE; pe++) begin
if(WT == WI) begin : genCopy
assign idat[pe] = s_axis_tdata[pe*WI+:WI];
end : genCopy
else begin
initial begin
if(FPARG) begin
$error("%m: Can't cast floating-point type.");
$finish;
end
end

if(WT > WI) begin : genWiden
assign idat[pe] = { {(WT-WI){SIGNED? s_axis_tdata[(pe+1)*WI-1] : 1'b0}}, s_axis_tdata[pe*WI+:WI] };
end : genWiden
else begin : genNarrow
// Saturate for clipping inputs
if(!SIGNED) begin
assign idat[pe] = |s_axis_tdata[pe*WI+WT+:WI-WT]? '1 : s_axis_tdata[pe*WI+:WT];
end
else begin
assign idat[pe] =
(s_axis_tdata[pe*WI+WT+:WI-WT] == '1) || (s_axis_tdata[pe*WI+WT+:WI-WT] == '0)? s_axis_tdata[pe*WI+:WT] :
{s_axis_tdata[(pe+1)*WI-1], {(WT-1){!s_axis_tdata[(pe+1)*WI-1]}}};
end
end : genNarrow
end
end

//-----------------------------------------------------------------------
// Kernel Implementation
thresholding #(
.N(N), .K(K), .C(C), .PE(PE),
.N(N), .K(WT), .C(C), .PE(PE),
.SIGNED(SIGNED), .FPARG(FPARG), .BIAS(BIAS),
.THRESHOLDS_PATH(THRESHOLDS_PATH), .USE_CONFIG(USE_AXILITE),
.DEPTH_TRIGGER_URAM(DEPTH_TRIGGER_URAM), .DEPTH_TRIGGER_BRAM(DEPTH_TRIGGER_BRAM),
Expand All @@ -157,7 +190,7 @@ module thresholding_axi #(
.cfg_en, .cfg_we, .cfg_a, .cfg_d,
.cfg_rack, .cfg_q,

.irdy(s_axis_tready), .ivld(s_axis_tvalid), .idat(s_axis_tdata),
.irdy(s_axis_tready), .ivld(s_axis_tvalid), .idat,
.ordy(m_axis_tready), .ovld(m_axis_tvalid), .odat(m_axis_tdata)
);

Expand Down
9 changes: 5 additions & 4 deletions finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
*/

module $MODULE_NAME_AXI_WRAPPER$ #(
parameter N = $N$, // output precision
parameter K = $M$, // input/threshold precision
parameter N = $N$, // output precision
parameter WI = $WI$, // input precision
parameter WT = $WT$, // threshold precision
parameter C = $C$, // Channels
parameter PE = $PE$, // Processing Parallelism, requires C = k*PE

Expand Down Expand Up @@ -87,7 +88,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #(
//- AXI Stream - Input --------------
output in0_V_TREADY,
input in0_V_TVALID,
input [((PE*K+7)/8)*8-1:0] in0_V_TDATA,
input [((PE*WI+7)/8)*8-1:0] in0_V_TDATA,

//- AXI Stream - Output -------------
input out_V_TREADY,
Expand All @@ -96,7 +97,7 @@ module $MODULE_NAME_AXI_WRAPPER$ #(
);

thresholding_axi #(
.N(N), .K(K), .C(C), .PE(PE),
.N(N), .WI(WI), .WT(WT), .C(C), .PE(PE),
.SIGNED(SIGNED),
.FPARG(FPARG),
.BIAS(BIAS),
Expand Down
2 changes: 1 addition & 1 deletion finn-rtllib/thresholding/sim/thresholding_axi_tb.sv
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ module thresholding_axi_tb #(
uwire ovld;
uwire [PE-1:0][N-1:0] odat;

thresholding_axi #(.N(N), .K(K), .C(C), .PE(PE), .SIGNED(0), .USE_AXILITE(1)) dut (
thresholding_axi #(.N(N), .WI(K), .WT(K), .C(C), .PE(PE), .SIGNED(0), .USE_AXILITE(1)) dut (
.ap_clk(clk), .ap_rst_n(!rst),

// Configuration
Expand Down
11 changes: 6 additions & 5 deletions src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ def prepare_codegen_rtl_values(self, model):
# Additionally, increase number of threshold steps to reflect new shape
expected_thresholds = 2**o_bitwidth - 1
n_thres_steps = self.get_nodeattr("numSteps")
wdt = self.get_weight_datatype()
if expected_thresholds != n_thres_steps:
min_val = DataType[input_data_type].min()
min_val = wdt.min()
thresholds = np.insert(thresholds, 0, min_val, axis=1)
bias = bias - 1
n_thres_steps += 1

# add dummy dimension as final dimension (that's what gets packed with next call)
t_expand = np.expand_dims(thresholds, axis=-1)
wdt = self.get_weight_datatype()
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 4)
t_packed = pack_innermost_dim_as_hex_string(
t_expand,
Expand Down Expand Up @@ -242,9 +242,10 @@ def prepare_codegen_rtl_values(self, model):
i_bitwidth = DataType[input_data_type].bitwidth()

code_gen_dict["$N$"] = [str(o_bitwidth)] # output precision - convert bitwidth to string
code_gen_dict["$M$"] = [
str(i_bitwidth)
] # input/threshold precision - convert bitwidth to string
code_gen_dict["$WT$"] = [
str(wdt.bitwidth())
] # threshold precision - convert bitwidth to string
code_gen_dict["$WI$"] = [str(i_bitwidth)] # input precision - convert bitwidth to string
code_gen_dict["$C$"] = [str(num_channels)] # number of channels
code_gen_dict["$BIAS$"] = [str(bias)] # activation bias value
code_gen_dict["$PE$"] = [str(pe)] # requires C = M*PE
Expand Down
10 changes: 8 additions & 2 deletions src/finn/transformation/fpgadataflow/convert_to_hw_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,16 @@ def apply(self, model):
thl_in_shape = model.get_tensor_shape(thl_input)
thl_thres_shape = model.get_tensor_shape(thl_threshold)
idt = model.get_tensor_datatype(thl_input)

tdt = model.get_tensor_datatype(thl_threshold)
# skip conversion for layers with float input
if not idt.is_integer():
continue
assert tdt.is_integer(), (
node.name
+ """: MultiThreshold cannot be converted
because thresholds are float type. Input data type is integer,
please run RoundAndClipThresholds to convert thresholds to integer."""
)

# check layout of inputs/outputs, and convert if needed
# check layout and convert if necessary
Expand Down Expand Up @@ -253,7 +259,7 @@ def apply(self, model):
PE=pe,
numSteps=thl_thres_shape[1],
inputDataType=idt.name,
weightDataType=idt.name,
weightDataType=tdt.name,
outputDataType=odt.name,
numInputVectors=list(thl_in_shape[:-1]),
ActVal=actval,
Expand Down
25 changes: 18 additions & 7 deletions tests/fpgadataflow/test_fpgadataflow_thresholding.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@


def generate_random_threshold_values(
input_data_type, num_input_channels, num_steps, narrow=False, per_tensor=False
data_type, num_input_channels, num_steps, narrow=False, per_tensor=False
):
if per_tensor:
num_input_channels = 1
if narrow:
num_steps -= 1

return np.random.randint(
input_data_type.min(),
input_data_type.max() + 1,
data_type.min(),
data_type.max() + 1,
(num_input_channels, num_steps),
).astype(np.float32)

Expand All @@ -76,6 +76,7 @@ def sort_thresholds_increasing(thresholds):
def make_single_multithresholding_modelwrapper(
thresholds,
input_data_type,
threshold_data_type,
output_data_type,
activation_bias,
num_input_vecs,
Expand Down Expand Up @@ -115,7 +116,7 @@ def make_single_multithresholding_modelwrapper(
model.set_tensor_datatype("inp", input_data_type)
model.set_tensor_datatype("outp", output_data_type)

model.set_tensor_datatype("thresh", input_data_type)
model.set_tensor_datatype("thresh", threshold_data_type)
model.set_initializer("thresh", thresholds)
return model

Expand All @@ -129,7 +130,15 @@ def make_single_multithresholding_modelwrapper(
],
)
@pytest.mark.parametrize("activation", [DataType["INT4"], DataType["BIPOLAR"]])
@pytest.mark.parametrize("input_data_type", [DataType["INT8"], DataType["UINT8"]])
@pytest.mark.parametrize(
"idt_tdt_cfg",
[
(DataType["INT8"], DataType["INT8"]),
(DataType["INT8"], DataType["INT9"]),
(DataType["UINT8"], DataType["UINT8"]),
(DataType["UINT8"], DataType["UINT9"]),
],
)
@pytest.mark.parametrize("fold", [-1, 1, 2])
@pytest.mark.parametrize("narrow", [True, False])
@pytest.mark.parametrize("per_tensor", [True, False])
Expand All @@ -143,7 +152,7 @@ def test_fpgadataflow_thresholding(
num_input_channels,
num_input_vecs,
activation,
input_data_type,
idt_tdt_cfg,
fold,
narrow,
per_tensor,
Expand All @@ -161,6 +170,7 @@ def test_fpgadataflow_thresholding(
)
if narrow and activation == DataType["BIPOLAR"]:
pytest.skip("Narrow needs to be false with biploar activation.")
input_data_type, threshold_data_type = idt_tdt_cfg
num_steps = activation.get_num_possible_values() - 1

if fold == -1:
Expand All @@ -179,7 +189,7 @@ def test_fpgadataflow_thresholding(

# Generate random thresholds and sort in ascending order
thresholds = generate_random_threshold_values(
input_data_type, num_input_channels, num_steps, narrow, per_tensor
threshold_data_type, num_input_channels, num_steps, narrow, per_tensor
)

# provide non-decreasing/ascending thresholds
Expand All @@ -189,6 +199,7 @@ def test_fpgadataflow_thresholding(
model = make_single_multithresholding_modelwrapper(
thresholds,
input_data_type,
threshold_data_type,
output_data_type,
activation_bias,
num_input_vecs,
Expand Down

0 comments on commit 6191c42

Please sign in to comment.