Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lookup] Relax input datatype constraints #31

Merged
merged 2 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/finn/custom_op/fpgadataflow/hls/lookup_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def execute_node(self, context, graph):
)

inp = context[node.input[0]]
assert inp.dtype == np.int64, "Inputs must be contained in int64 ndarray"
# assert inp.dtype == np.int64, "Inputs must be contained in int64 ndarray"
assert inp.shape == exp_ishape, """Input shape doesn't match expected shape."""
export_idt = self.get_input_datatype()
odt = self.get_output_datatype()
Expand Down
16 changes: 12 additions & 4 deletions src/finn/custom_op/fpgadataflow/rtl/streamingfifo_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,18 @@ def execute_node(self, context, graph):
elif mode == "rtlsim":
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
# create a npy file for the input of the node
assert (
str(inp.dtype) == "float32"
), """Input datatype is
not float32 as expected."""

# Make sure the inpout has the right container datatype
if inp.dtype != np.float32:
# Issue a warning to make the user aware of this type-cast
warnings.warn(
f"{node.name}: Changing input datatype from "
f"{inp.dtype} to {np.float32}"
)
# Convert the input to floating point representation as the
# container datatype
inp = inp.astype(np.float32)

expected_inp_shape = self.get_folded_input_shape()
reshaped_input = inp.reshape(expected_inp_shape)
if DataType[self.get_nodeattr("dataType")] == DataType["BIPOLAR"]:
Expand Down
Loading