Skip to content

Commit

Permalink
Enable non-narrow weights for DSP48E2. Expose version in core selection.
Browse files Browse the repository at this point in the history
  • Loading branch information
preusser committed May 24, 2024
1 parent 9e2ba5c commit 739d644
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 187 deletions.
148 changes: 98 additions & 50 deletions finn-rtllib/mvu/mvu_4sx4u.sv
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ module mvu_4sx4u #(
int unsigned SIMD,
int unsigned ACCU_WIDTH,

int unsigned VERSION = 1,
int unsigned VERSION = 1, // Version 1 (DSP48E1) *must* commit to NARROW_WEIGHTS
bit SIGNED_ACTIVATIONS = 0,
bit NARROW_WEIGHTS = 0, // Weights from [-7:7] rather than [-8:7]
bit FORCE_BEHAVIORAL = 0
)(
// Global Control
Expand All @@ -62,6 +63,54 @@ module mvu_4sx4u #(
`endif
FORCE_BEHAVIORAL;

//-----------------------------------------------------------------------
// Determine Lane Configuration
typedef struct {
int unsigned OFFSET[4:0];
int unsigned LO_WIDTH[3:0];
int unsigned HI_WIDTH[2:0];
int unsigned LO_WIDTH_MAX; // exluding leftmost lane
int unsigned HI_WIDTH_MAX; // exluding leftmost lane
} slicing_t;
function slicing_t sliceLanes();
automatic slicing_t slicing;

// Determine Lane Offsets
unique case(VERSION)
1: begin
if(!NARROW_WEIGHTS) begin
$error("%m: Need NARROW_WEIGHTS for DSP48E1.");
$finish;
end
slicing.OFFSET = '{ ACCU_WIDTH+21, 21, 14, 7, 0 };
end
2: begin
slicing.OFFSET = NARROW_WEIGHTS?
'{ ACCU_WIDTH+23, 23, 16, 8, 0 } :
'{ ACCU_WIDTH+22, 22, 15, 8, 0 };
end
endcase

// Derive other Lane Attributes
for(int unsigned i = 0; i < 4; i++) begin
automatic int unsigned lw = slicing.OFFSET[i+1] - slicing.OFFSET[i];
slicing.LO_WIDTH[i] = lw;

if(i < 3) begin
automatic int unsigned hw = 1 + $clog2(2**(ACCU_WIDTH-lw-1)+SIMD);
slicing.HI_WIDTH[i] = hw;

if(lw > slicing.LO_WIDTH_MAX) slicing.LO_WIDTH_MAX = lw;
if(hw > slicing.HI_WIDTH_MAX) slicing.HI_WIDTH_MAX = hw;
end
end

return slicing;
endfunction : sliceLanes
localparam slicing_t SLICING = sliceLanes();
localparam int unsigned A_WIDTH = 23 + 2*VERSION; // Width of A datapath

// Compute the count of decendents for all nodes in the reduction trees.
typedef int unsigned leave_load_t[2*SIMD-1];
function leave_load_t init_leave_loads();
automatic leave_load_t res;
Expand All @@ -79,12 +128,6 @@ module mvu_4sx4u #(
assign vld = L[5];

// Stages #1 - #3: DSP Lanes + cross-lane canaries duplicated with SIMD parallelism
localparam int unsigned D[4:0] = // Lane offsets
VERSION == 1? '{ ACCU_WIDTH+21, 21, 14, 7, 0 } :
VERSION == 2? '{ ACCU_WIDTH+23, 23, 16, 8, 0 } :
/* else */ '{ default: 0 };
localparam int unsigned A_WIDTH = 23 + 2*VERSION; // Width of A datapath

localparam int unsigned PIPE_COUNT = (PE+3)/4;
for(genvar c = 0; c < PIPE_COUNT; c++) begin : genPipes

Expand All @@ -102,7 +145,7 @@ module mvu_4sx4u #(
logic [26:0] dd;
logic [ 1:0] xx[3:1];
if(1) begin : blkVectorize
uwire [3:0] ww[PE_END - PE_BEG];
uwire signed [3:0] ww[PE_END - PE_BEG];
for(genvar pe = 0; pe < PE_END - PE_BEG; pe++) begin
assign ww[pe] = w[PE_BEG + pe][s];
if(pe) begin
Expand All @@ -127,15 +170,19 @@ module mvu_4sx4u #(
dd = '0;
aa = '0;
for(int unsigned pe = 0; pe < PE_END - PE_BEG; pe++) begin
dd[D[pe + PE_REM]+:3] = ww[pe];
automatic int unsigned ofs = SLICING.OFFSET[pe + PE_REM];
dd[ofs+:3] = ww[pe];
assert(!NARROW_WEIGHTS || (ww[pe] != -8)) else begin
$warning("Weight of -8 violates NARROW_WEIGHTS commitment.");
end

// The sign of the weights are generally put on the subtracted A port.
// However, when coinciding with the actual sign bit position of the
// multiplier input path, it also goes onto the D input. This prevents
// sign extensions that may happen when a DSP primitive is auto-promoted
// to a newer generation.
if(D[pe + PE_REM]+3 == A_WIDTH-1) dd[D[pe + PE_REM]+3] = ww[pe][3];
else aa[D[pe + PE_REM]+3] = ww[pe][3];
if(ofs+3 == A_WIDTH-1) dd[ofs+3] = ww[pe][3];
else aa[ofs+3] = ww[pe][3];
end
end
end : blkVectorize
Expand Down Expand Up @@ -441,14 +488,14 @@ module mvu_4sx4u #(
X1 <= xx;
X2 <= X1;
foreach(X3[i]) begin
X3[i] <= X2[i] + (L[3]? 2'h0 : pp[D[i]+:2]);
X3[i] <= X2[i] + (L[3]? 2'h0 : pp[SLICING.OFFSET[i]+:2]);
end
end
end

// Derive actual cross-lane overflows
for(genvar i = 0; i < 3; i++) begin
assign h3[s][i] = pp[D[i+1]+:2] - X3[i+1];
assign h3[s][i] = pp[SLICING.OFFSET[i+1]+:2] - X3[i+1];
end
assign p3[s] = pp;

Expand All @@ -457,51 +504,55 @@ module mvu_4sx4u #(
// Stage #4: Cross-SIMD Reduction

// Count leaves reachable from each node
localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 1}; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop
localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 1 }; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop

uwire signed [ACCU_WIDTH-1:0] up4;
uwire signed [$clog2(2**(ACCU_WIDTH-8)+SIMD):0] hi4[3]; // min LO_WIDTH=7
uwire [$clog2(SIMD)+7 :0] lo4[3]; // max LO_WIDTH=8
uwire signed [ SLICING.HI_WIDTH_MAX-1:0] hi4[3];
uwire [$clog2(SIMD)+SLICING.LO_WIDTH_MAX-1:0] lo4[3];
for(genvar i = 0; i < 4; i++) begin
localparam int unsigned LO_WIDTH = D[i+1] - D[i];
localparam int unsigned HI_WIDTH = 1 + $clog2(2**(ACCU_WIDTH-LO_WIDTH-1)+SIMD);

// Conclusive high part accumulation
if(i >= PE_REM && i < 3) begin : genHi
// Adder Tree across all SIMD high contributions, each from [-1:1]
uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s][i];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
uwire signed [$clog2(1+LEAVE_LOAD[n]):0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]);
assign tree[n] = s;
end
if(i < 3) begin : genHi
if(i < PE_REM) assign hi4[i] = '0;
else begin
localparam int unsigned HI_WIDTH = SLICING.HI_WIDTH[i];

// Adder Tree across all SIMD high contributions, each from [-1:1]
uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s][i];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
uwire signed [$clog2(1+LEAVE_LOAD[n]):0] s = $signed(tree[2*n+1]) + $signed(tree[2*n+2]);
assign tree[n] = s;
end

// High Sideband Accumulation
logic signed [HI_WIDTH-1:0] Hi4 = 0;
always_ff @(posedge clk) begin
if(rst) Hi4 <= 0;
else if(en) begin
automatic logic signed [HI_WIDTH:0] h = $signed(L[4]? 0 : Hi4) + $signed(tree[0]);
assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin
$error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH);
$stop;
// High Sideband Accumulation
logic signed [HI_WIDTH-1:0] Hi4 = 0;
always_ff @(posedge clk) begin
if(rst) Hi4 <= 0;
else if(en) begin
automatic logic signed [HI_WIDTH:0] h = $signed(L[4]? 0 : Hi4) + $signed(tree[0]);
assert(h[HI_WIDTH] == h[HI_WIDTH-1]) else begin
$error("%m: Accumulation overflow for ACCU_WIDTH=%0d", ACCU_WIDTH);
$stop;
end
Hi4 <= h;
end
Hi4 <= h;
end
assign hi4[i] = Hi4;

end
assign hi4[i] = Hi4;
end : genHi
else if (i < 3) begin : genHiZero
assign hi4[i] = '0;
end : genHiZero

// Conclusive low part accumulation (all unsigned arithmetic)
if(i >= PE_REM) begin : blkLo
if(i < PE_REM) assign lo4[i] = '0;
else begin : genLo
localparam int unsigned LO_WIDTH = SLICING.LO_WIDTH[i];

// Adder Tree across all SIMD low contributions
localparam int unsigned ROOT_WIDTH = $clog2(1 + SIMD*(2**LO_WIDTH-1));
uwire [2*SIMD-2:0][ROOT_WIDTH-1:0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][D[i]+:LO_WIDTH];
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][SLICING.OFFSET[i]+:LO_WIDTH];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
localparam int unsigned NODE_WIDTH = $clog2(1 + LEAVE_LOAD[n]*(2**LO_WIDTH-1));
Expand All @@ -517,10 +568,7 @@ module mvu_4sx4u #(

if(i == 3) assign up4 = Lo4;
else assign lo4[i] = Lo4;
end : blkLo
else begin : blkLoZero
assign lo4[i] = '0;
end : blkLoZero
end : genLo

end

Expand All @@ -530,9 +578,9 @@ module mvu_4sx4u #(
if(rst) Res5 <= '{ default: 0 };
else if(en) begin
Res5[3] <= up4 - hi4[2];
Res5[2] <= $signed({ hi4[2], {(D[3] - D[2]){1'b0}} }) + $signed({ 1'b0, lo4[2] }) - hi4[1];
Res5[1] <= $signed({ hi4[1], {(D[2] - D[1]){1'b0}} }) + $signed({ 1'b0, lo4[1] }) - hi4[0];
Res5[0] <= $signed({ hi4[0], {(D[1] - D[0]){1'b0}} }) + $signed({ 1'b0, lo4[0] });
Res5[2] <= $signed({ hi4[2], {(SLICING.LO_WIDTH[2]){1'b0}} }) + $signed({ 1'b0, lo4[2] }) - hi4[1];
Res5[1] <= $signed({ hi4[1], {(SLICING.LO_WIDTH[1]){1'b0}} }) + $signed({ 1'b0, lo4[1] }) - hi4[0];
Res5[0] <= $signed({ hi4[0], {(SLICING.LO_WIDTH[0]){1'b0}} }) + $signed({ 1'b0, lo4[0] });
end
end

Expand Down
19 changes: 17 additions & 2 deletions finn-rtllib/mvu/mvu_vvu_axi.sv
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ module mvu_vvu_axi #(
int unsigned ACTIVATION_WIDTH,
int unsigned WEIGHT_WIDTH,
int unsigned ACCU_WIDTH,
bit NARROW_WEIGHTS = 0,
bit SIGNED_ACTIVATIONS = 0,

bit PUMPED_COMPUTE = 0,
Expand Down Expand Up @@ -306,8 +307,22 @@ module mvu_vvu_axi #(
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
);
"mvu_4sx4u":
mvu_4sx4u #(.PE(PE), .SIMD(DSP_SIMD), .ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) core (
"mvu_4sx4u_dsp48e1":
mvu_4sx4u #(
.PE(PE), .SIMD(DSP_SIMD),
.ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .NARROW_WEIGHTS(NARROW_WEIGHTS),
.VERSION(1), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)
) core (
.clk(dsp_clk), .rst, .en(dsp_en),
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
);
"mvu_4sx4u_dsp48e2":
mvu_4sx4u #(
.PE(PE), .SIMD(DSP_SIMD),
.ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .NARROW_WEIGHTS(NARROW_WEIGHTS),
.VERSION(2), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)
) core (
.clk(dsp_clk), .rst, .en(dsp_en),
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
Expand Down
Loading

0 comments on commit 739d644

Please sign in to comment.