Skip to content

Commit

Permalink
Merge pull request #31 from iksnagreb/fix/lookup
Browse files Browse the repository at this point in the history
[Lookup] Relax input datatype constraints
  • Loading branch information
fpjentzsch authored Feb 5, 2025
2 parents bfc66a0 + 3de81d0 commit 64282e5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
14 changes: 13 additions & 1 deletion src/finn/custom_op/fpgadataflow/hls/lookup_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import numpy as np
import os
import warnings
from math import ceil, log2
from qonnx.core.datatype import DataType

Expand Down Expand Up @@ -273,7 +274,18 @@ def execute_node(self, context, graph):
)

inp = context[node.input[0]]
assert inp.dtype == np.int64, "Inputs must be contained in int64 ndarray"

# Make sure the input has the right container datatype
if inp.dtype is not np.float32:
# Issue a warning to make the user aware of this type-cast
warnings.warn(
f"{node.name}: Changing input container datatype from "
f"{inp.dtype} to {np.float32}"
)
# Convert the input to floating point representation as the
# container datatype
inp = inp.astype(np.float32)

assert inp.shape == exp_ishape, """Input shape doesn't match expected shape."""
export_idt = self.get_input_datatype()
odt = self.get_output_datatype()
Expand Down
16 changes: 12 additions & 4 deletions src/finn/custom_op/fpgadataflow/rtl/streamingfifo_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,18 @@ def execute_node(self, context, graph):
elif mode == "rtlsim":
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
# create a npy file for the input of the node
assert (
str(inp.dtype) == "float32"
), """Input datatype is
not float32 as expected."""

# Make sure the input has the right container datatype
if inp.dtype is not np.float32:
# Issue a warning to make the user aware of this type-cast
warnings.warn(
f"{node.name}: Changing input container datatype from "
f"{inp.dtype} to {np.float32}"
)
# Convert the input to floating point representation as the
# container datatype
inp = inp.astype(np.float32)

expected_inp_shape = self.get_folded_input_shape()
reshaped_input = inp.reshape(expected_inp_shape)
if DataType[self.get_nodeattr("dataType")] == DataType["BIPOLAR"]:
Expand Down

0 comments on commit 64282e5

Please sign in to comment.