Skip to content

Commit

Permalink
[mvu vvu axi]: reworked flow control and backpressure handling by tpr…
Browse files Browse the repository at this point in the history
…eusser
  • Loading branch information
mmrahorovic committed Jan 11, 2024
1 parent 44f6e0f commit 9b2cceb
Showing 1 changed file with 61 additions and 69 deletions.
130 changes: 61 additions & 69 deletions finn-rtllib/mvu/mvu_vvu_axi.sv
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ module mvu_vvu_axi #(
// Safely deducible parameters
localparam int unsigned WEIGHT_STREAM_WIDTH = PE * SIMD * WEIGHT_WIDTH,
localparam int unsigned WEIGHT_STREAM_WIDTH_BA = (WEIGHT_STREAM_WIDTH + 7) / 8 * 8,
localparam int unsigned INPUT_STREAM_WIDTH = (IS_MVU ? 1 : PE) * SIMD * ACTIVATION_WIDTH,
localparam int unsigned INPUT_STREAM_WIDTH = SIMD * ACTIVATION_WIDTH,
localparam int unsigned INPUT_STREAM_WIDTH_BA = (INPUT_STREAM_WIDTH + 7) / 8 * 8,
localparam int unsigned OUTPUT_STREAM_WIDTH = PE * ACCU_WIDTH,
localparam int unsigned OUTPUT_STREAM_WIDTH_BA = (OUTPUT_STREAM_WIDTH + 7) / 8 * 8,
localparam int unsigned SF = MW / SIMD,
localparam int unsigned NF = IS_MVU ? MH / PE : 1
localparam int unsigned NF = MH / PE
)
(
// Global Control
Expand Down Expand Up @@ -119,81 +119,73 @@ module mvu_vvu_axi #(
$finish;
end
end
if (!IS_MVU) begin
if (COMPUTE_CORE != "mvu_vvu_8sx9_dsp58") begin
$error("VVU only supported on DSP58");
$finish;
end
end
end

uwire clk = ap_clk;
uwire rst = !ap_rst_n;

typedef logic [INPUT_STREAM_WIDTH-1 : 0] mvauin_t;

uwire mvauin_t amvau;
//- Replay to Accommodate Neuron Fold -----------------------------------
typedef logic [PE*SIMD-1:0][ACTIVATION_WIDTH-1:0] mvu_flatin_t;
uwire mvu_flatin_t amvau;
uwire alast;
uwire afin;
uwire avld;
uwire ardy;

replay_buffer #(.LEN(SF), .REP(NF), .W($bits(mvauin_t))) activation_replay (
replay_buffer #(.LEN(SF), .REP(NF), .W($bits(mvu_flatin_t))) activation_replay (
.clk, .rst,
.ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), .idat(mvauin_t'(s_axis_input_tdata)),
.ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), .idat(mvu_flatin_t'(s_axis_input_tdata)),
.ovld(avld), .ordy(ardy), .odat(amvau), .olast(alast), .ofin(afin)
);

//-------------------- Input control --------------------\\
//- Unflatten inputs into structured matrices ---------------------------
typedef logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH -1:0] mvu_w_t;
typedef logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] mvu_a_t;

uwire mvu_w_t mvu_w = s_axis_weights_tdata;
uwire mvu_a_t mvu_a = amvau;

//- Flow Control Bracket around Compute Core ----------------------------
uwire en;
uwire istb = avld && s_axis_weights_tvalid;
assign ardy = en && s_axis_weights_tvalid;
assign s_axis_weights_tready = en && avld;

//-------------------- Core MVU/VVU --------------------\\
uwire ovld;
uwire [PE-1:0][ACCU_WIDTH-1:0] odat;
uwire mvauin_t amvau_i;
typedef logic [WEIGHT_STREAM_WIDTH-1 : 0] mvauin_weight_t;

if (IS_MVU) begin : genMVUInput
assign amvau_i = amvau;
end : genMVUInput
else begin : genVVUInput
// The input stream will have the channels interleaved for VVU when PE>1
// Hence, we need to 'untangle' the input stream, i.e. [..][SIMD*PE][..] --> [..][PE][SIMD][..]
// Note that for each 'SIMD' (S) and 'PE' (P) element, we have something like:
// (S_0, P_0), ..., (S_0, P_i), (S_1, P_0), ..., (S_1, P_i), ..., (S_i, P_i) which we need to 'untangle' to
// (S_0, P_0), ..., (S_i, P_0), (S_0, P_1), ..., (S_i,, P_1), ..., (S_i, P_i)
localparam int num_of_elements = PE*SIMD;
for (genvar i=0; i<num_of_elements; i++) begin : genRewire
assign amvau_i[i*ACTIVATION_WIDTH +: ACTIVATION_WIDTH] = (PE > 1) ?
amvau[(i/SIMD + (i*PE % num_of_elements) + 1) * ACTIVATION_WIDTH -1: (i/SIMD + (i*PE % num_of_elements)) * ACTIVATION_WIDTH]
: amvau[i*ACTIVATION_WIDTH +: ACTIVATION_WIDTH];
end : genRewire
end : genVVUInput
//- Instantiate compute core ----------------------------
typedef logic [PE-1:0][ACCU_WIDTH-1:0] dsp_p_t;
uwire dsp_vld;
uwire dsp_p_t dsp_p;

uwire dsp_clk = ap_clk;
uwire dsp_en = en;
uwire dsp_last = alast && avld;
uwire dsp_zero = !istb;
uwire mvu_w_t dsp_w = mvu_w;
uwire mvu_a_t dsp_a = mvu_a;
uwire ovld = dsp_vld;
uwire dsp_p_t odat = dsp_p;

case(COMPUTE_CORE)
"mvu_vvu_8sx9_dsp58":
mvu_vvu_8sx9_dsp58 #(.IS_MVU(IS_MVU), .PE(PE), .SIMD(SIMD), .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH),
.ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .SEGMENTLEN(SEGMENTLEN),
.FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) core (
.clk, .rst, .en,
.last(alast && avld), .zero(!istb), .w(s_axis_weights_tdata), .a(amvau_i),
.vld(ovld), .p(odat)
.clk(dsp_clk), .rst, .en(dsp_en),
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
);
"mvu_4sx4u":
mvu_4sx4u #(.PE(PE), .SIMD(SIMD), .ACCU_WIDTH(ACCU_WIDTH), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) core (
.clk, .rst, .en,
.last(alast && avld), .zero(!istb), .w(mvauin_weight_t'(s_axis_weights_tdata)), .a(amvau_i),
.vld(ovld), .p(odat)
.clk(dsp_clk), .rst, .en(dsp_en),
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
);
"mvu_8sx8u_dsp48":
mvu_8sx8u_dsp48 #(.PE(PE), .SIMD(SIMD), .ACCU_WIDTH(ACCU_WIDTH), .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH),
.SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .FORCE_BEHAVIORAL(FORCE_BEHAVIORAL)) core (
.clk, .rst, .en,
.last(alast && avld), .zero(!istb), .w(mvauin_weight_t'(s_axis_weights_tdata)), .a(amvau_i),
.vld(ovld), .p(odat)
.clk(dsp_clk), .rst, .en(dsp_en),
.last(dsp_last), .zero(dsp_zero), .w(dsp_w), .a(dsp_a),
.vld(dsp_vld), .p(dsp_p)
);
default: initial begin
$error("Unrecognized COMPUTE_CORE '%s'", COMPUTE_CORE);
Expand All @@ -202,41 +194,41 @@ module mvu_vvu_axi #(
endcase

//-------------------- Output register slice --------------------\\
// Make `en`computation independent from external inputs.
// Drive all outputs from registers.
struct packed {
logic vld;
logic rdy;
logic [PE-1:0][ACCU_WIDTH-1:0] dat;
} A = '{ vld: 0, default: 'x};

assign en = !A.vld || !ovld;

uwire b_load;
always_ff @(posedge clk) begin
if(rst) A <= '{ vld: 0, default: 'x };
else if(!A.vld || b_load) begin
A.vld <= ovld && en;
for(int unsigned i = 0; i < PE; i++) begin
// CR-1148862:
// A.dat[i] <= odat[i];
automatic logic [ACCU_WIDTH-1:0] v = odat[i];
A.dat[i] <= v[ACCU_WIDTH-1:0];
end
end
end

} A = '{ rdy: 1, default: 'x }; // side-step register used when encountering backpressure
struct packed {
logic vld;
logic [PE-1:0][ACCU_WIDTH-1:0] dat;
} B = '{ vld: 0, default: 'x};
} B = '{ vld: 0, default: 'x }; // ultimate output register

assign en = A.rdy;
uwire b_load = !B.vld || m_axis_output_tready;

assign b_load = !B.vld || m_axis_output_tready;
always_ff @(posedge clk) begin
if(rst) B <= '{ vld: 0, default: 'x };
if(rst) begin
A <= '{ rdy: 1, default: 'x };
B <= '{ vld: 0, default: 'x };
end
else begin
if(b_load) B <= '{ vld: A.vld, dat: A.dat};
if(A.rdy) A.dat <= odat;
A.rdy <= (A.rdy && !ovld) || b_load;

if(b_load) begin
B <= '{
vld: ovld || !A.rdy,
dat: A.rdy? odat : A.dat
};
end
end
end

assign m_axis_output_tvalid = B.vld;
// Why would we need a sign extension here potentially creating a higher signal load into the next FIFO?
// These extra bits should never be used. Why not 'x them out?
assign m_axis_output_tdata = { {(OUTPUT_STREAM_WIDTH_BA-OUTPUT_STREAM_WIDTH){B.dat[PE-1][ACCU_WIDTH-1]}}, B.dat};


endmodule : mvu_vvu_axi

0 comments on commit 9b2cceb

Please sign in to comment.