Skip to content

Commit

Permalink
[Deconv] Update test and add comments to transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
auphelia committed Jan 3, 2024
1 parent c741fae commit 14929f4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
25 changes: 16 additions & 9 deletions src/finn/transformation/fpgadataflow/infer_pixel_padding_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@


class InferPixelPaddingDeconv(Transformation):
def __init__(self):
super().__init__()
"""
Lowering and conversion of ConvTranspose (NCHW) nodes to
FMPadding_Pixel + Im2Col + MatMul (NHWC) surrounded by Transpose nodes
note: this transformation produces a mix of hw layers and non hw layers
to implement this on an FPGA the Im2Col and MatMul nodes need to be converted to hw layers
after applying this transformation and the resulting transpose nodes need to be streamlined.
See deconv test case under tests/fpgadataflow for an example.
"""

def apply(self, model):
graph = model.graph
Expand All @@ -17,6 +23,14 @@ def apply(self, model):
for n in graph.node:
node_ind += 1
if n.op_type == "ConvTranspose":
# conversion currently only supported for group=1
group = get_by_name(n.attribute, "group").i
if group != 1:
warnings.warn(
"%s : Only group=1 is currently supported. Can't infer PixelPaddingDeconv."
% n.name
)
continue
deconv_input = n.input[0]
deconv_output = n.output[0]
idt = model.get_tensor_datatype(deconv_input)
Expand All @@ -25,13 +39,6 @@ def apply(self, model):
k_w = get_by_name(n.attribute, "kernel_shape").ints[1]
stride_h = get_by_name(n.attribute, "strides").ints[0]
stride_w = get_by_name(n.attribute, "strides").ints[1]
group = get_by_name(n.attribute, "group").i
if group != 1:
warnings.warn(
"%s : Only group=1 is currently supported. Can't infer PixelPaddingDeconv."
% n.name
)
continue
weight_name = n.input[1]
W_conv = model.get_initializer(weight_name)
ifm_ch = model.get_tensor_shape(n.input[0])[1] # assume NCHW
Expand Down
22 changes: 16 additions & 6 deletions tests/fpgadataflow/test_fpgadataflow_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import pytest

import numpy as np
import os
from onnx import TensorProto, helper
from qonnx.core.datatype import DataType
Expand All @@ -38,6 +39,7 @@
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model

import finn.core.onnx_exec as oxe
from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.convert_to_hls_layers import (
InferConvInpGen,
Expand Down Expand Up @@ -177,21 +179,29 @@ def test_fpgadataflow_deconv(idim, stride, ifm_ch, ofm_ch, simd, pe, k, padding,

expected_oshape = (1, ofm_ch, odim_h, odim_w)
y_expected = oxe.execute_onnx(ref_model, input_dict)["outp"]

# cppsim
if exec_mode == "cppsim":
model = model.transform(PrepareCppSim())
model = model.transform(CompileCppSim())
model = model.transform(SetExecMode("cppsim"))
y_produced = oxe.execute_onnx(model, input_dict)["outp"]
assert y_produced.shape == expected_oshape
assert (y_produced == y_expected).all()

# rtlsim
else:
model = model.transform(PrepareIP(test_fpga_part, target_clk_ns))
model = model.transform(HLSSynthIP())
model = model.transform(PrepareRTLSim())
model = model.transform(SetExecMode("rtlsim"))
y_produced = oxe.execute_onnx(model, input_dict)["outp"]
assert y_produced.shape == expected_oshape
assert (y_produced == y_expected).all()

y_produced = oxe.execute_onnx(model, input_dict)["outp"]
assert y_produced.shape == expected_oshape
assert (y_produced == y_expected).all()

if exec_mode == "rtlsim":
node = model.get_nodes_by_op_type("FMPadding_Pixel")[0]
inst = getCustomOp(node)
cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
exp_cycles_dict = model.analysis(exp_cycles_per_layer)
exp_cycles = exp_cycles_dict[node.name]
assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
assert exp_cycles != 0

0 comments on commit 14929f4

Please sign in to comment.