Skip to content

Commit

Permalink
Fix bug in wavelength input as list
Browse files Browse the repository at this point in the history
  • Loading branch information
DeanHazineh committed Apr 19, 2024
1 parent 38ba4b2 commit 37c5d6a
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions dflat/metasurface/optical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,6 @@ def forward(self, params, wavelength, pre_normalized=True):
Returns:
list: Amplitude and Phase of shape [B, pol, Lam, H, W] where pol is 1 or 2.
"""
num_ch = params.shape[-1]
assert num_ch == (
len(self.param_bounds) - 1
), "Channel dimension is inconsistent with loaded model"
assert len(params.shape) == 4
assert len(wavelength.shape) == 1
b, h, w, c = params.shape

if not pre_normalized:
params = self.normalize(params)
wavelength = self.normalize_wavelength(wavelength)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = (
torch.tensor(params, dtype=torch.float32).to(device)
Expand All @@ -65,6 +53,18 @@ def forward(self, params, wavelength, pre_normalized=True):
)
torch_zero = torch.tensor(0.0, dtype=x.dtype).to(device=x.device)

num_ch = x.shape[-1]
assert num_ch == (
len(self.param_bounds) - 1
), "Channel dimension is inconsistent with loaded model"
assert len(x.shape) == 4
assert len(lam.shape) == 1
b, h, w, c = x.shape

if not pre_normalized:
x = self.normalize(x)
lam = self.normalize_wavelength(lam)

x = rearrange(x, "b h w c -> (b h w) c")
out = []
for li in lam:
Expand Down

0 comments on commit 37c5d6a

Please sign in to comment.