Skip to content

Commit

Permalink
Apply to 1x1 kernel, simplify logic, fix edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
fpjentzsch committed Nov 21, 2023
1 parent 4c80cf8 commit b89dd62
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 37 deletions.
14 changes: 1 addition & 13 deletions finn-rtllib/swg/swg_template_parallel.sv
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ module $TOP_MODULE_NAME$_impl #(
// counters/address registers
logic signed [$clog2(LAST_READ_ELEM+1)+1-1:0] Newest_buffered_elem = -1;
logic [$clog2(LAST_READ_ELEM+1)+1-1:0] Current_elem = FIRST_WRITE_ELEM;
logic [$clog2(LAST_READ_ELEM+1)+1-1:0] First_elem_next_window = 0;

// control registers/signals
logic Writing_done = 0;
Expand All @@ -146,13 +145,7 @@ module $TOP_MODULE_NAME$_impl #(
uwire write_blocked = write_cmd && !out_V_V_TREADY && !Write_done;

uwire reading_done = Newest_buffered_elem == LAST_READ_ELEM;
uwire read_cmd =
!reading_done && ( // if there is still an input element left to read
Writing_done || ( // if writing is done (e.g. for skipped rows at FM end due to stride)
$signed(((Newest_buffered_elem - ($signed(BUF_ELEM_TOTAL) - 1)))) < $signed(First_elem_next_window) &&
$signed(((Newest_buffered_elem - ($signed(BUF_ELEM_TOTAL) - 1)))) < $signed(Current_elem)
) // (over-)write to buffer if oldest buffered element will no longer be needed
);
uwire read_cmd = !reading_done && (Writing_done || Newest_buffered_elem <= $signed(Current_elem));
uwire read_ok = read_cmd && in0_V_V_TVALID && !write_blocked;

// includes waiting on W if W-only cycle: wait only on W no R/W to wait for
Expand Down Expand Up @@ -186,7 +179,6 @@ module $TOP_MODULE_NAME$_impl #(
if(!ap_rst_n) begin
Newest_buffered_elem <= -1;
Current_elem <= FIRST_WRITE_ELEM;
First_elem_next_window <= 0;
Writing_done <= 0;
end
else begin
Expand All @@ -199,14 +191,11 @@ module $TOP_MODULE_NAME$_impl #(
// todo: allow for read overlapping between feature maps (i.e., reading first elements from next FM while still writing last window of current FM)
Newest_buffered_elem <= -1;
Current_elem <= FIRST_WRITE_ELEM;
First_elem_next_window <= 0;
Writing_done <= 0;
end
end

if (write_ok) begin
First_elem_next_window <= First_elem_next_window + tail_incr;

// check if this is the last write cycle (Writing_done will be true afterwards)
if (Current_elem == LAST_WRITE_ELEM) begin
Writing_done <= 1;
Expand All @@ -215,7 +204,6 @@ module $TOP_MODULE_NAME$_impl #(
// start processing of next FM if reading is done already, or completes in the same cycle
Newest_buffered_elem <= -1;
Current_elem <= FIRST_WRITE_ELEM;
First_elem_next_window <= 0;
Writing_done <= 0;
end
end
Expand Down
32 changes: 12 additions & 20 deletions src/finn/custom_op/fpgadataflow/convolutioninputgenerator_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,6 @@ def prepare_codegen_parallel(self):
dilation = self.get_nodeattr("Dilation")
simd = self.get_nodeattr("SIMD")
M = self.get_nodeattr("M")
depthwise = self.get_nodeattr("depthwise")

k_h, k_w = k
h, w = ifm_dim
Expand Down Expand Up @@ -713,7 +712,6 @@ def prepare_codegen_parallel(self):
]

# re-use default controller loop structure
code_gen_dict["$IS_DEPTHWISE$"] = ["1"] if depthwise else ["0"]
loop_h_iterations = out_dim_h
loop_w_iterations = out_dim_w
loop_kh_iterations = channel_factor
Expand All @@ -731,20 +729,14 @@ def prepare_codegen_parallel(self):
code_gen_dict["$INNERMOST_STATE$"] = ["STATE_LOOP_KH"]
loop_kh_iterations -= 1 # -1 because state is initial state

# set head and tail address increment values
tail_incr_w = (stride_w - 1) * channel_factor + 1
tail_incr_h = (
(skip_columns + (kernel_width - 1)) * channel_factor + 1
) + ( # remaining line
(stride_h - 1) * w * channel_factor
) # skip lines
tail_incr_last_window = stride_w * channel_factor

# set head address increment values
addr_incr_end_simd = 1
addr_incr_end_window_elem = 1
addr_incr_end_window_row = 1
addr_incr_end_window = tail_incr_w
addr_incr_end_row = tail_incr_h
addr_incr_end_window = (stride_w - 1) * channel_factor + 1
addr_incr_end_row = ((skip_columns + (kernel_width - 1)) * channel_factor + 1) + (
(stride_h - 1) * w * channel_factor
)

# add init value for CURRENT_ELEM counter = last elem of first window
code_gen_dict["$FIRST_WRITE_ELEM$"] = [str(buffer_min_size - 1)]
Expand Down Expand Up @@ -775,9 +767,6 @@ def prepare_codegen_parallel(self):
abs(addr_incr_end_window_row) + 1,
abs(addr_incr_end_window) + 1,
abs(addr_incr_end_row) + 1,
abs(tail_incr_w) + 1,
abs(tail_incr_h) + 1,
abs(tail_incr_last_window) + 1,
)
)
)
Expand All @@ -787,9 +776,11 @@ def prepare_codegen_parallel(self):
code_gen_dict["$HEAD_INCR_KH$"] = [str(addr_incr_end_window_row)]
code_gen_dict["$HEAD_INCR_W$"] = [str(addr_incr_end_window)]
code_gen_dict["$HEAD_INCR_H$"] = [str(addr_incr_end_row)]
code_gen_dict["$TAIL_INCR_W$"] = [str(tail_incr_w)]
code_gen_dict["$TAIL_INCR_H$"] = [str(tail_incr_h)]
code_gen_dict["$TAIL_INCR_LAST$"] = [str(tail_incr_last_window)]
# not used, set to zero:
code_gen_dict["$TAIL_INCR_W$"] = ["0"]
code_gen_dict["$TAIL_INCR_H$"] = ["0"]
code_gen_dict["$TAIL_INCR_LAST$"] = ["0"]
code_gen_dict["$IS_DEPTHWISE$"] = ["0"]

code_gen_dict["$SIMD$"] = [str(simd)]
code_gen_dict["$MMV_IN$"] = [str(mmv_in)]
Expand Down Expand Up @@ -968,8 +959,9 @@ def select_impl_style(self):
# choose implementation style
if mmv_out > 1 or (k_h == 1 and k_w == 1):
impl_style = "parallel"
if depthwise:
if depthwise or (k_h == 1 and k_w == 1):
# allow SIMD < IFM_CH in depthwise mode (VVAU supports the resulting data layout)
# also allowed for 1x1 kernel since depthwise and non-depthwise are equivalent
assert ifm_ch % simd == 0, "Constraint violated: SIMD must divide IFMChannels"
else:
assert ifm_ch == simd, "Constraint violated: SIMD must be equal to IFMChannels"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,9 @@ def test_fpgadataflow_slidingwindow_rtl(
pytest.skip("Illegal convolution configuration: kernel or stride > FM dimension")
if (k_h == 1 and dilation_h != 1) or (k_w == 1 and dilation_w != 1):
pytest.skip("Illegal convolution configuration: dilation for unitary kernel dim")
if (stride_h > k_h) or (stride_w > k_w) and not parallel_window:
if ((stride_h > k_h) or (stride_w > k_w)) and not (parallel_window or (k_h == 1 and k_w == 1)):
pytest.skip("Not all combinations for stride > k edge case supported in default mode")
if k_h == 1 and k_w == 1 and simd != ifm_ch:
pytest.skip("1x1 Kernel only supported in parallel mode (SIMD=C)")
if parallel_window and simd != ifm_ch and not dw:
if parallel_window and simd != ifm_ch and not (dw or (k_h == 1 and k_w == 1)):
pytest.skip("Parallel window requires SIMD=C for non-depthwise case")

ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, 0, dilation_h)
Expand Down

0 comments on commit b89dd62

Please sign in to comment.