From 739d64468d0d754f6cd1b54045f8ad7af466202e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20B=2E=20Preu=C3=9Fer?= Date: Fri, 24 May 2024 19:40:03 +0100 Subject: [PATCH] Enable non-narrow weights for DSP48E2. Expose version in core selection. --- finn-rtllib/mvu/mvu_4sx4u.sv | 148 ++++++++++------ finn-rtllib/mvu/mvu_vvu_axi.sv | 19 +- finn-rtllib/mvu/tb/mvu_axi_tb.sv | 286 ++++++++++++++++--------------- 3 files changed, 266 insertions(+), 187 deletions(-) diff --git a/finn-rtllib/mvu/mvu_4sx4u.sv b/finn-rtllib/mvu/mvu_4sx4u.sv index 7f3d6961e3..2f2e1c0d23 100644 --- a/finn-rtllib/mvu/mvu_4sx4u.sv +++ b/finn-rtllib/mvu/mvu_4sx4u.sv @@ -36,8 +36,9 @@ module mvu_4sx4u #( int unsigned SIMD, int unsigned ACCU_WIDTH, - int unsigned VERSION = 1, + int unsigned VERSION = 1, // Version 1 (DSP48E1) *must* commit to NARROW_WEIGHTS bit SIGNED_ACTIVATIONS = 0, + bit NARROW_WEIGHTS = 0, // Weights from [-7:7] rather than [-8:7] bit FORCE_BEHAVIORAL = 0 )( // Global Control @@ -62,6 +63,54 @@ module mvu_4sx4u #( `endif FORCE_BEHAVIORAL; + //----------------------------------------------------------------------- + // Determine Lane Configuration + typedef struct { + int unsigned OFFSET[4:0]; + int unsigned LO_WIDTH[3:0]; + int unsigned HI_WIDTH[2:0]; + int unsigned LO_WIDTH_MAX; // excluding leftmost lane + int unsigned HI_WIDTH_MAX; // excluding leftmost lane + } slicing_t; + function slicing_t sliceLanes(); + automatic slicing_t slicing; + + // Determine Lane Offsets + unique case(VERSION) + 1: begin + if(!NARROW_WEIGHTS) begin + $error("%m: Need NARROW_WEIGHTS for DSP48E1."); + $finish; + end + slicing.OFFSET = '{ ACCU_WIDTH+21, 21, 14, 7, 0 }; + end + 2: begin + slicing.OFFSET = NARROW_WEIGHTS? 
+ '{ ACCU_WIDTH+23, 23, 16, 8, 0 } : + '{ ACCU_WIDTH+22, 22, 15, 8, 0 }; + end + endcase + + // Derive other Lane Attributes + for(int unsigned i = 0; i < 4; i++) begin + automatic int unsigned lw = slicing.OFFSET[i+1] - slicing.OFFSET[i]; + slicing.LO_WIDTH[i] = lw; + + if(i < 3) begin + automatic int unsigned hw = 1 + $clog2(2**(ACCU_WIDTH-lw-1)+SIMD); + slicing.HI_WIDTH[i] = hw; + + if(lw > slicing.LO_WIDTH_MAX) slicing.LO_WIDTH_MAX = lw; + if(hw > slicing.HI_WIDTH_MAX) slicing.HI_WIDTH_MAX = hw; + end + end + + return slicing; + endfunction : sliceLanes + localparam slicing_t SLICING = sliceLanes(); + localparam int unsigned A_WIDTH = 23 + 2*VERSION; // Width of A datapath + + // Compute the count of descendants for all nodes in the reduction trees. typedef int unsigned leave_load_t[2*SIMD-1]; function leave_load_t init_leave_loads(); automatic leave_load_t res; @@ -79,12 +128,6 @@ module mvu_4sx4u #( assign vld = L[5]; // Stages #1 - #3: DSP Lanes + cross-lane canaries duplicated with SIMD parallelism - localparam int unsigned D[4:0] = // Lane offsets - VERSION == 1? '{ ACCU_WIDTH+21, 21, 14, 7, 0 } : - VERSION == 2? 
'{ ACCU_WIDTH+23, 23, 16, 8, 0 } : - /* else */ '{ default: 0 }; - localparam int unsigned A_WIDTH = 23 + 2*VERSION; // Width of A datapath - localparam int unsigned PIPE_COUNT = (PE+3)/4; for(genvar c = 0; c < PIPE_COUNT; c++) begin : genPipes @@ -102,7 +145,7 @@ module mvu_4sx4u #( logic [26:0] dd; logic [ 1:0] xx[3:1]; if(1) begin : blkVectorize - uwire [3:0] ww[PE_END - PE_BEG]; + uwire signed [3:0] ww[PE_END - PE_BEG]; for(genvar pe = 0; pe < PE_END - PE_BEG; pe++) begin assign ww[pe] = w[PE_BEG + pe][s]; if(pe) begin @@ -127,15 +170,19 @@ module mvu_4sx4u #( dd = '0; aa = '0; for(int unsigned pe = 0; pe < PE_END - PE_BEG; pe++) begin - dd[D[pe + PE_REM]+:3] = ww[pe]; + automatic int unsigned ofs = SLICING.OFFSET[pe + PE_REM]; + dd[ofs+:3] = ww[pe]; + assert(!NARROW_WEIGHTS || (ww[pe] != -8)) else begin + $warning("Weight of -8 violates NARROW_WEIGHTS commitment."); + end // The sign of the weights are generally put on the subtracted A port. // However, when coinciding with the actual sign bit position of the // multiplier input path, it also goes onto the D input. This prevents // sign extensions that may happen when a DSP primitive is auto-promoted // to a newer generation. - if(D[pe + PE_REM]+3 == A_WIDTH-1) dd[D[pe + PE_REM]+3] = ww[pe][3]; - else aa[D[pe + PE_REM]+3] = ww[pe][3]; + if(ofs+3 == A_WIDTH-1) dd[ofs+3] = ww[pe][3]; + else aa[ofs+3] = ww[pe][3]; end end end : blkVectorize @@ -441,14 +488,14 @@ module mvu_4sx4u #( X1 <= xx; X2 <= X1; foreach(X3[i]) begin - X3[i] <= X2[i] + (L[3]? 2'h0 : pp[D[i]+:2]); + X3[i] <= X2[i] + (L[3]? 2'h0 : pp[SLICING.OFFSET[i]+:2]); end end end // Derive actual cross-lane overflows for(genvar i = 0; i < 3; i++) begin - assign h3[s][i] = pp[D[i+1]+:2] - X3[i+1]; + assign h3[s][i] = pp[SLICING.OFFSET[i+1]+:2] - X3[i+1]; end assign p3[s] = pp; @@ -457,51 +504,55 @@ module mvu_4sx4u #( // Stage #4: Cross-SIMD Reduction // Count leaves reachable from each node - localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? 
init_leave_loads() : '{ default: 1}; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop + localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 1 }; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop uwire signed [ACCU_WIDTH-1:0] up4; - uwire signed [$clog2(2**(ACCU_WIDTH-8)+SIMD):0] hi4[3]; // min LO_WIDTH=7 - uwire [$clog2(SIMD)+7 :0] lo4[3]; // max LO_WIDTH=8 + uwire signed [ SLICING.HI_WIDTH_MAX-1:0] hi4[3]; + uwire [$clog2(SIMD)+SLICING.LO_WIDTH_MAX-1:0] lo4[3]; for(genvar i = 0; i < 4; i++) begin - localparam int unsigned LO_WIDTH = D[i+1] - D[i]; - localparam int unsigned HI_WIDTH = 1 + $clog2(2**(ACCU_WIDTH-LO_WIDTH-1)+SIMD); // Conclusive high part accumulation - if(i >= PE_REM && i < 3) begin : genHi - // Adder Tree across all SIMD high contributions, each from [-1:1] - uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree; - for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s][i]; - for(genvar n = 0; n < SIMD-1; n++) begin - // Sum truncated to actual maximum bit width at this node - uwire signed [$clog2(1+LEAVE_LOAD[n]):0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]); - assign tree[n] = s; - end + if(i < 3) begin : genHi + if(i < PE_REM) assign hi4[i] = '0; + else begin + localparam int unsigned HI_WIDTH = SLICING.HI_WIDTH[i]; + + // Adder Tree across all SIMD high contributions, each from [-1:1] + uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree; + for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s][i]; + for(genvar n = 0; n < SIMD-1; n++) begin + // Sum truncated to actual maximum bit width at this node + uwire signed [$clog2(1+LEAVE_LOAD[n]):0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]); + assign tree[n] = s; + end - // High Sideband Accumulation - logic signed [HI_WIDTH-1:0] Hi4 = 0; - always_ff @(posedge clk) begin - if(rst) Hi4 <= 0; - else if(en) begin - automatic logic signed [HI_WIDTH:0] h = 
$signed(L[4]? 0 : Hi4) + $signed(tree[0]); - assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin - $error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH); - $stop; + // High Sideband Accumulation + logic signed [HI_WIDTH-1:0] Hi4 = 0; + always_ff @(posedge clk) begin + if(rst) Hi4 <= 0; + else if(en) begin + automatic logic signed [HI_WIDTH:0] h = $signed(L[4]? 0 : Hi4) + $signed(tree[0]); + assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin + $error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH); + $stop; + end + Hi4 <= h; end - Hi4 <= h; end + assign hi4[i] = Hi4; + end - assign hi4[i] = Hi4; end : genHi - else if (i < 3) begin : genHiZero - assign hi4[i] = '0; - end : genHiZero // Conclusive low part accumulation (all unsigned arithmetic) - if(i >= PE_REM) begin : blkLo + if(i < PE_REM) assign lo4[i] = '0; + else begin : genLo + localparam int unsigned LO_WIDTH = SLICING.LO_WIDTH[i]; + // Adder Tree across all SIMD low contributions localparam int unsigned ROOT_WIDTH = $clog2(1 + SIMD*(2**LO_WIDTH-1)); uwire [2*SIMD-2:0][ROOT_WIDTH-1:0] tree; - for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][D[i]+:LO_WIDTH]; + for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][SLICING.OFFSET[i]+:LO_WIDTH]; for(genvar n = 0; n < SIMD-1; n++) begin // Sum truncated to actual maximum bit width at this node localparam int unsigned NODE_WIDTH = $clog2(1 + LEAVE_LOAD[n]*(2**LO_WIDTH-1)); @@ -517,10 +568,7 @@ module mvu_4sx4u #( if(i == 3) assign up4 = Lo4; else assign lo4[i] = Lo4; - end : blkLo - else begin : blkLoZero - assign lo4[i] = '0; - end : blkLoZero + end : genLo end @@ -530,9 +578,9 @@ module mvu_4sx4u #( if(rst) Res5 <= '{ default: 0 }; else if(en) begin Res5[3] <= up4 - hi4[2]; - Res5[2] <= $signed({ hi4[2], {(D[3] - D[2]){1'b0}} }) + $signed({ 1'b0, lo4[2] }) - hi4[1]; - Res5[1] <= $signed({ hi4[1], {(D[2] - D[1]){1'b0}} }) + $signed({ 1'b0, lo4[1] }) - hi4[0]; - Res5[0] <= $signed({ hi4[0], {(D[1] - D[0]){1'b0}} }) + 
$signed({ 1'b0, lo4[0] }); + Res5[2] <= $signed({ hi4[2], {(SLICING.LO_WIDTH[2]){1'b0}} }) + $signed({ 1'b0, lo4[2] }) - hi4[1]; + Res5[1] <= $signed({ hi4[1], {(SLICING.LO_WIDTH[1]){1'b0}} }) + $signed({ 1'b0, lo4[1] }) - hi4[0]; + Res5[0] <= $signed({ hi4[0], {(SLICING.LO_WIDTH[0]){1'b0}} }) + $signed({ 1'b0, lo4[0] }); end end diff --git a/finn-rtllib/mvu/mvu_vvu_axi.sv b/finn-rtllib/mvu/mvu_vvu_axi.sv index 6498530113..35325abdf9 100644 --- a/finn-rtllib/mvu/mvu_vvu_axi.sv +++ b/finn-rtllib/mvu/mvu_vvu_axi.sv @@ -55,6 +55,7 @@ module mvu_vvu_axi #( int unsigned ACTIVATION_WIDTH, int unsigned WEIGHT_WIDTH, int unsigned ACCU_WIDTH, + bit NARROW_WEIGHTS = 0, bit SIGNED_ACTIVATIONS = 0, bit PUMPED_COMPUTE = 0, @@ -306,8 +307,22 @@ module mvu_vvu_axi #( .last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a), .vld(dsp_vld), .p(dsp_p) ); - "mvu_4sx4u": - mvu_4sx4u #(.PE(PE), .SIMD(DSP_SIMD), .ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) core ( + "mvu_4sx4u_dsp48e1": + mvu_4sx4u #( + .PE(PE), .SIMD(DSP_SIMD), + .ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .NARROW_WEIGHTS(NARROW_WEIGHTS), + .VERSION(1), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL) + ) core ( + .clk(dsp_clk), .rst, .en(dsp_en), + .last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a), + .vld(dsp_vld), .p(dsp_p) + ); + "mvu_4sx4u_dsp48e2": + mvu_4sx4u #( + .PE(PE), .SIMD(DSP_SIMD), + .ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .NARROW_WEIGHTS(NARROW_WEIGHTS), + .VERSION(2), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL) + ) core ( .clk(dsp_clk), .rst, .en(dsp_en), .last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a), .vld(dsp_vld), .p(dsp_p) diff --git a/finn-rtllib/mvu/tb/mvu_axi_tb.sv b/finn-rtllib/mvu/tb/mvu_axi_tb.sv index d3532bcfea..f16c40db34 100644 --- a/finn-rtllib/mvu/tb/mvu_axi_tb.sv +++ b/finn-rtllib/mvu/tb/mvu_axi_tb.sv @@ -70,7 +70,7 @@ module mvu_axi_tb(); uwire ap_clk = clk; - // Generate activations + // 
Generate shared Activations typedef logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] activation_t; typedef activation_t activation_vector_t[SF]; @@ -82,158 +82,174 @@ module mvu_axi_tb(); activation_vector_t ACTIVATIONS = init_ACTIVATIONS(); - struct { - activation_t dat; - logic vld; - logic rdy; - } activations; - - initial begin - activations.vld = 0; - activations.dat = 'X; - @(posedge clk iff ap_rst_n); - - for (int i=0; i= 0; - @(posedge clk); - end while (!(activations.vld === 1 && activations.rdy === 1)); + // Run parallel instances across DSP versions and NARROW_WEIGHTS + bit [2:1][1:0] done = { 2: 2'b00, 1: 2'b01 }; // [ver][narrow] + always_comb begin + if(&done) begin + $display("Test completed."); + $finish; end - - activations.vld <= 0; - activations.dat <= 'x; end - // Generate weights - typedef logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] weight_t; - typedef weight_t weight_matrix_t[NF][SF]; + for(genvar ver = 1; ver <= 2; ver++) begin : genVersion + for(genvar narrow = (ver == 1); narrow <= 1; narrow++) begin : genNarrowWide + + // Activations Feed + struct { + activation_t dat; + logic vld; + logic rdy; + } activations; + + initial begin + activations.vld = 0; + activations.dat = 'X; + @(posedge clk iff ap_rst_n); + + for(int unsigned i = 0; i < SF; i++) begin + while($urandom()%7 == 0) @(posedge clk); + activations.dat <= ACTIVATIONS[i]; + activations.vld <= 1; + @(posedge clk iff activations.rdy); + activations.dat <= 'x; + activations.vld <= 0; + end + end - function weight_matrix_t init_WEIGHTS; - automatic weight_matrix_t res; - std::randomize(res); - for(int unsigned nf = 0; nf < NF; nf++) begin - for(int unsigned sf = 0; sf < SF; sf++) begin - for(int unsigned pe = 0; pe < PE; pe++) begin - for(int unsigned simd = 0; simd < SIMD; simd++) begin - if(res[nf][sf][pe][simd] == (1 << (WEIGHT_WIDTH-1))) begin - res[nf][sf][pe][simd]++; + // Instance-specific Weights (may be narrow) + typedef logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] weight_t; + typedef 
weight_t weight_matrix_t[NF][SF]; + + function weight_matrix_t init_WEIGHTS; + automatic weight_matrix_t res; + std::randomize(res); + if(narrow) begin // increment all weights of -8 + for(int unsigned nf = 0; nf < NF; nf++) begin + for(int unsigned sf = 0; sf < SF; sf++) begin + for(int unsigned pe = 0; pe < PE; pe++) begin + for(int unsigned simd = 0; simd < SIMD; simd++) begin + if(res[nf][sf][pe][simd] == (1 << (WEIGHT_WIDTH-1))) begin + res[nf][sf][pe][simd]++; + end + end end end end end - end - return res; - endfunction : init_WEIGHTS; - - weight_matrix_t WEIGHTS = init_WEIGHTS(); - - struct { - weight_t dat; - logic vld; - logic rdy; - } weights; - - initial begin - weights.vld = 0; - weights.dat = 'X; - @(posedge clk iff ap_rst_n); - - weights.vld <= 1; - for (int i=0; i1 - // Hence, we need to 'untangle' the input stream, i.e. [..][SIMD*PE][..] --> [..][PE][SIMD][..] - // Note that for each 'SIMD' (S) and 'PE' (P) element, we have something like: - // (S_0, P_0), ..., (S_0, P_i), (S_1, P_0), ..., (S_1, P_i), ..., (S_i, P_i) which we need to 'untangle' to - // (S_0, P_0), ..., (S_i, P_0), (S_0, P_1), ..., (S_i,, P_1), ..., (S_i, P_i) - for (int i = 0; i < NF; i++) begin - for (int j = 0; j < SF; j++) begin - for (int k = 0; k < PE; k++) begin - for (int l = 0; l < SIMD; l++) begin - if (SIGNED_ACTIVATIONS) - res[i][k] = $signed(res[i][k]) + $signed(a[j][l]) * $signed(w[i][j][k][l]); - else - res[i][k] = $signed(res[i][k]) + $signed({1'b0, a[j][l]}) * $signed(w[i][j][k][l]); + // Function to compute golden output + // a: [SF][SIMD-1:0][ACTIVATION_WIDTH-1:0] + // a: [SF][PE*SIMD-1:0][ACTIVATION_WIDTH-1:0] + // w: [NF][SF][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] + typedef logic signed [PE-1:0][ACCU_WIDTH-1:0] output_t; + typedef output_t output_vector_t [NF]; + + struct { + output_t dat; + logic vld; + logic rdy; + } outputs; + + function output_vector_t check_output(activation_vector_t a, weight_matrix_t w); + automatic output_vector_t res = '{default: 0}; + // 
The input stream will have the channels interleaved for VVU when PE>1 + // Hence, we need to 'untangle' the input stream, i.e. [..][SIMD*PE][..] --> [..][PE][SIMD][..] + // Note that for each 'SIMD' (S) and 'PE' (P) element, we have something like: + // (S_0, P_0), ..., (S_0, P_i), (S_1, P_0), ..., (S_1, P_i), ..., (S_i, P_i) which we need to 'untangle' to + // (S_0, P_0), ..., (S_i, P_0), (S_0, P_1), ..., (S_i,, P_1), ..., (S_i, P_i) + for (int i = 0; i < NF; i++) begin + for (int j = 0; j < SF; j++) begin + for (int k = 0; k < PE; k++) begin + for (int l = 0; l < SIMD; l++) begin + if (SIGNED_ACTIVATIONS) + res[i][k] = $signed(res[i][k]) + $signed(a[j][l]) * $signed(w[i][j][k][l]); + else + res[i][k] = $signed(res[i][k]) + $signed({1'b0, a[j][l]}) * $signed(w[i][j][k][l]); + end end end end - end - return res; - endfunction : check_output; - - output_vector_t GOLDEN_OUTPUT = check_output(ACTIVATIONS, WEIGHTS); - - int unsigned NF_CNT = 0; - initial begin - outputs.rdy = 0; - while (NF_CNT < NF) begin - // Loop until both rdy & vld are asserted - do begin - outputs.rdy <= $urandom()%7 >= 0; - @(posedge clk iff ap_rst_n); - end while (!(outputs.rdy === 1 && outputs.vld === 1)); - - // Compare produced outputs against golden outputs - foreach(outputs.dat[i]) begin - assert ($signed(outputs.dat[i]) == $signed(GOLDEN_OUTPUT[NF_CNT][i])) $display(">>> [t=%0t] Test succeeded (NF=%0d)! Computed / GOLDEN = %0d / %0d", $time, NF_CNT, $signed(outputs.dat[i]), $signed(GOLDEN_OUTPUT[NF_CNT][i])); - else begin - $error(">>> [t=%0t] TEST failed (NF=%0d)! 
Computed / GOLDEN = %0d / %0d", $time, NF_CNT, $signed(outputs.dat[i]), $signed(GOLDEN_OUTPUT[NF_CNT][i])); - $stop; + return res; + endfunction : check_output; + + output_vector_t GOLDEN_OUTPUT = check_output(ACTIVATIONS, WEIGHTS); + initial begin + outputs.rdy = 0; + @(posedge clk iff ap_rst_n); + + for(int unsigned nf = 0; nf < NF; nf++) begin + while($urandom()%13 == 0) @(posedge clk); + outputs.rdy <= 1; + @(posedge clk iff outputs.vld); + outputs.rdy <= 0; + + // Compare produced outputs against golden outputs + foreach(outputs.dat[i]) begin + assert ($signed(outputs.dat[i]) == $signed(GOLDEN_OUTPUT[nf][i])) begin + $display(">>> [t=%0t] Test succeeded (nf=%0d)! Computed / GOLDEN = %0d / %0d", $time, nf, $signed(outputs.dat[i]), $signed(GOLDEN_OUTPUT[nf][i])); + end + else begin + $error(">>> [t=%0t] TEST failed (nf=%0d)! Computed / GOLDEN = %0d / %0d", $time, nf, $signed(outputs.dat[i]), $signed(GOLDEN_OUTPUT[nf][i])); + $stop; + end end end - NF_CNT += 1; + done[ver][narrow] = 1; end - $finish; - end - - // Instantiate DUT - mvu_vvu_axi #( - .IS_MVU(IS_MVU), - .COMPUTE_CORE(COMPUTE_CORE), - .MW(MW), - .MH(MH), - .PE(PE), - .SIMD(SIMD), - .ACTIVATION_WIDTH(ACTIVATION_WIDTH), - .WEIGHT_WIDTH(WEIGHT_WIDTH), - .ACCU_WIDTH(ACCU_WIDTH), - .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), - .SEGMENTLEN(SEGMENTLEN), - .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL), - .M_REG_LUT(M_REG_LUT) - ) - dut ( - .ap_clk, .ap_rst_n, .s_axis_weights_tdata({ {WEIGHT_WIDTH_BA_DELTA{1'b0}}, weights.dat }), .s_axis_weights_tvalid(weights.vld), - .s_axis_weights_tready(weights.rdy), .s_axis_input_tdata({ {ACTIVATION_WIDTH_BA_DELTA{1'b0}}, activations.dat }), .s_axis_input_tvalid(activations.vld), - .s_axis_input_tready(activations.rdy), .m_axis_output_tdata(outputs.dat), .m_axis_output_tvalid(outputs.vld), - .m_axis_output_tready(outputs.rdy) - ); + // Instantiate DUT + mvu_vvu_axi #( + .IS_MVU(IS_MVU), + .COMPUTE_CORE(ver == 1? 
"mvu_4sx4u_dsp48e1" : "mvu_4sx4u_dsp48e2"), + .MW(MW), + .MH(MH), + .PE(PE), + .SIMD(SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .ACCU_WIDTH(ACCU_WIDTH), + .NARROW_WEIGHTS(narrow), + .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), + .SEGMENTLEN(SEGMENTLEN), + .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL), + .M_REG_LUT(M_REG_LUT) + ) + dut ( + .ap_clk, .ap_rst_n, .s_axis_weights_tdata({ {WEIGHT_WIDTH_BA_DELTA{1'b0}}, weights.dat }), .s_axis_weights_tvalid(weights.vld), + .s_axis_weights_tready(weights.rdy), .s_axis_input_tdata({ {ACTIVATION_WIDTH_BA_DELTA{1'b0}}, activations.dat }), .s_axis_input_tvalid(activations.vld), + .s_axis_input_tready(activations.rdy), .m_axis_output_tdata(outputs.dat), .m_axis_output_tvalid(outputs.vld), + .m_axis_output_tready(outputs.rdy) + ); + + end : genNarrowWide + end : genVersion endmodule : mvu_axi_tb