
Commit

Add batching
DeanHazineh committed Nov 7, 2024
1 parent 57f8942 commit d80320d
Showing 2 changed files with 23 additions and 10 deletions.
17 changes: 14 additions & 3 deletions dflat/metasurface/optical_model.py
@@ -35,7 +35,7 @@ def training_step(self, x, y):
        pred = self.model(x)
        return self.loss(pred, y)

-    def forward(self, params, wavelength, pre_normalized=True):
+    def forward(self, params, wavelength, pre_normalized=True, batch_size=None):
        """Predict the cell optical response from the passed in design parameters and wavelength
        Args:
@@ -44,6 +44,7 @@ def forward(self, params, wavelength, pre_normalized=True):
            pre_normalized (bool, optional): Flag to indicate if the passed in params and wavelength are already normalized
                to the range [0,1]. If False, the passed in tensors will be automatically normalized based on the config
                settings. Defaults to True.
+            batch_size (int, optional): Number of cells to evaluate at once via model batching.
        Returns:
            list: Amplitude and Phase of shape [B, pol, Lam, H, W] where pol is 1 or 2.
@@ -62,12 +63,16 @@ def forward(self, params, wavelength, pre_normalized=True):
        torch_zero = torch.tensor(0.0, dtype=x.dtype).to(device=x.device)

        num_ch = x.shape[-1]
+        b, h, w, c = x.shape
        assert num_ch == (
            len(self.param_bounds) - 1
        ), "Channel dimension is inconsistent with loaded model"
        assert len(x.shape) == 4
        assert len(lam.shape) == 1
-        b, h, w, c = x.shape
+        if batch_size is not None:
+            assert isinstance(batch_size, int), "batch size must be an integer"
+        else:
+            batch_size = b * h * w

        if not pre_normalized:
            x = self.normalize(x)
@@ -76,7 +81,13 @@ def forward(self, params, wavelength, pre_normalized=True):
        x = rearrange(x, "b h w c -> (b h w) c")
        out = []
        for li in lam:
-            out.append(self.model(torch.cat((x, li.repeat(x.shape[0], 1)), dim=1)))
+            li_repeated = li.repeat(x.shape[0], 1)  # Repeat `li` once for all rows in x
+            chout = [
+                self.model(torch.cat((x[start:end], li_repeated[start:end]), dim=1))
+                for start in range(0, x.shape[0], batch_size)
+                for end in [min(start + batch_size, x.shape[0])]
+            ]
+            out.append(torch.cat(chout, dim=0))  # Concatenate results for each `li`
        out = torch.stack(out)

        g = int(out.shape[-1] / 3)
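
The nested comprehension in the hunk above walks the flattened (b h w) cell axis in [start, end) windows of at most batch_size rows, runs the model on each window, and concatenates the results, so the output is identical to the previous single-shot call, just with bounded peak memory. A minimal self-contained sketch of the same pattern, using a stand-in nn.Linear in place of the trained cell model (purely illustrative):

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 3)       # stand-in for the trained cell MLP (illustrative only)
x = torch.rand(10_000, 3)     # flattened (b*h*w) rows of shape parameters
li = torch.rand(1)            # one normalized wavelength
li_repeated = li.repeat(x.shape[0], 1)
batch_size = 4096

full = model(torch.cat((x, li_repeated), dim=1))       # unbatched, as before this commit
chunks = [
    model(torch.cat((x[start:end], li_repeated[start:end]), dim=1))
    for start in range(0, x.shape[0], batch_size)
    for end in [min(start + batch_size, x.shape[0])]
]
assert torch.allclose(full, torch.cat(chunks, dim=0))  # chunked result matches the full pass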
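Existing callers are unaffected, since batch_size defaults to None and falls back to evaluating all b * h * w cells at once. A usage sketch, where optical_model stands in for an already-loaded cell model exposing the forward() above and the shape-parameter count D = 3 is only an assumption:

import torch

D = 3                                         # assumed; in general D = model.dim_in - 1
params = torch.rand(1, 256, 256, D)           # [B, H, W, D] design parameters
wavelength = torch.tensor([532e-9, 635e-9])   # wavelengths in meters
# pre_normalized=False asks forward() to normalize params/wavelength from the config settings.
amp, phase = optical_model.forward(
    params, wavelength, pre_normalized=False, batch_size=16384
)
# amp and phase have shape [B, pol, Lam, H, W]; batch_size only bounds how many of the
# B*H*W cells are pushed through the model per call, not the returned values.
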
16 changes: 9 additions & 7 deletions dflat/metasurface/reverse_lookup.py
@@ -16,6 +16,7 @@ def reverse_lookup_optimize(
    max_iter=1000,
    opt_phase_only=False,
    force_cpu=False,
+    batch_size=None,
):
    """Given a stack of wavelength dependent amplitude and phase profiles, runs a reverse optimization to identify the nanostructures that
    implements the desired profile across wavelength by minimizing the mean absolute errors of complex fields.
@@ -28,6 +29,7 @@ def reverse_lookup_optimize(
        lr (float, optional): Optimization learning rate. Defaults to 1e-1.
        err_thresh (float, optional): Early termination threshold. Defaults to 0.1.
        max_iter (int, optional): Maximum number of steps. Defaults to 1000.
+        batch_size (int, optional): Number of cells to evaluate at once via model batching.
    Returns:
        list: Returns normalized and unnormalized metasurface design parameters of shape [B, H, W, D] where D is the number of shape parameters. Last item in list is the MAE loss for each step.
@@ -48,21 +50,19 @@ def reverse_lookup_optimize(
    pg = model.dim_out // 3
    assert pg == P, f"Polarization dimension of amp, phase (dim1) expected to be {pg}."

-    # z = np.random.rand(B, H, W, model.dim_in - 1)
-    z = np.zeros((B, H, W, model.dim_in - 1))
+    shape_dim = model.dim_in - 1
+    # z = np.random.rand(B, H, W, shape_dim)
+    z = np.zeros((B, H, W, shape_dim))
    z = torch.tensor(z, device=device, dtype=torch.float32, requires_grad=True)

    wavelength = (
        torch.tensor(wavelength_set_m)
        if not torch.is_tensor(wavelength_set_m)
        else wavelength_set_m
    )
    wavelength = wavelength.to(dtype=torch.float32, device=device)
    wavelength = model.normalize_wavelength(wavelength)

-    # optimize
-    optimizer = optim.AdamW([z], lr=lr)
    torch_zero = torch.tensor(0.0, dtype=z.dtype, device=device)

    amp = (
        torch.tensor(amp, dtype=torch.float32, device=device)
        if not torch.is_tensor(amp)
@@ -77,9 +77,11 @@ def reverse_lookup_optimize(
        torch.complex(torch_zero, phase)
    )

+    # Optimize
    err = 1e3
    steps = 0
    err_list = []
+    optimizer = optim.AdamW([z], lr=lr)
    pbar = tqdm(total=max_iter, desc="Optimization Progress")
    while err > err_thresh:
        if steps >= max_iter:
@@ -88,7 +90,7 @@

        optimizer.zero_grad()
        pred_amp, pred_phase = model(
-            latent_to_param(z), wavelength, pre_normalized=True
+            latent_to_param(z), wavelength, pre_normalized=True, batch_size=batch_size
        )

        if opt_phase_only:
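
In the reverse lookup, batch_size is simply threaded through to the forward model inside every optimization step, which helps when B*H*W is large. A usage sketch under the assumption that the leading positional arguments are the target amplitude/phase stacks and the wavelength list described in the docstring; their exact names, order, and any model-selection arguments are not visible in this diff:

import numpy as np
from dflat.metasurface.reverse_lookup import reverse_lookup_optimize  # import path inferred from the file location

# Illustrative targets; real shapes are [B, P, Lam, H, W] per the docstring above.
amp = np.ones((1, 1, 2, 128, 128), dtype=np.float32)
phase = np.zeros((1, 1, 2, 128, 128), dtype=np.float32)
wavelength_set_m = [532e-9, 635e-9]

result = reverse_lookup_optimize(
    amp, phase, wavelength_set_m,   # assumed leading arguments (not shown in this hunk)
    max_iter=500,
    batch_size=16384,               # new in this commit: cap cells evaluated per model call
)
# Per the docstring, result holds the normalized and unnormalized design parameters
# of shape [B, H, W, D], with the per-step MAE loss as the last item.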
