diff --git a/dflat/metasurface/optical_model.py b/dflat/metasurface/optical_model.py
index f5cace9..ef97d0e 100644
--- a/dflat/metasurface/optical_model.py
+++ b/dflat/metasurface/optical_model.py
@@ -35,7 +35,7 @@ def training_step(self, x, y):
         pred = self.model(x)
         return self.loss(pred, y)
 
-    def forward(self, params, wavelength, pre_normalized=True):
+    def forward(self, params, wavelength, pre_normalized=True, batch_size=None):
         """Predict the cell optical response from the passed in design parameters and wavelength
 
         Args:
@@ -44,6 +44,7 @@ def forward(self, params, wavelength, pre_normalized=True):
             pre_normalized (bool, optional): Flag to indicate if the passed in params and wavelength are already
                 normalized to the range [0,1]. If False, the passed in tensors will be automatically normalized based on
                 the config settings. Defaults to True.
+            batch_size (int, optional): Number of cells to evaluate at once via model batching. Defaults to None (all cells in a single pass).
 
         Returns:
             list: Amplitude and Phase of shape [B, pol, Lam, H, W] where pol is 1 or 2.
@@ -62,12 +63,16 @@ def forward(self, params, wavelength, pre_normalized=True):
         torch_zero = torch.tensor(0.0, dtype=x.dtype).to(device=x.device)
 
         num_ch = x.shape[-1]
+        b, h, w, c = x.shape
         assert num_ch == (
             len(self.param_bounds) - 1
         ), "Channel dimension is inconsistent with loaded model"
         assert len(x.shape) == 4
         assert len(lam.shape) == 1
-        b, h, w, c = x.shape
+        if batch_size is not None:
+            assert isinstance(batch_size, int), "batch size must be an integer"
+        else:
+            batch_size = b * h * w
 
         if not pre_normalized:
             x = self.normalize(x)
@@ -76,7 +81,13 @@ def forward(self, params, wavelength, pre_normalized=True):
         x = rearrange(x, "b h w c -> (b h w) c")
         out = []
         for li in lam:
-            out.append(self.model(torch.cat((x, li.repeat(x.shape[0], 1)), dim=1)))
+            li_repeated = li.repeat(x.shape[0], 1)  # Repeat `li` once for all rows in x
+            chout = [
+                self.model(torch.cat((x[start:end], li_repeated[start:end]), dim=1))
+                for start in range(0, x.shape[0], batch_size)
+                for end in [min(start + batch_size, x.shape[0])]
+            ]
+            out.append(torch.cat(chout, dim=0))  # Concatenate chunk results for this `li`
         out = torch.stack(out)
 
         g = int(out.shape[-1] / 3)
diff --git a/dflat/metasurface/reverse_lookup.py b/dflat/metasurface/reverse_lookup.py
index 0599058..c54e969 100644
--- a/dflat/metasurface/reverse_lookup.py
+++ b/dflat/metasurface/reverse_lookup.py
@@ -16,6 +16,7 @@ def reverse_lookup_optimize(
     max_iter=1000,
     opt_phase_only=False,
     force_cpu=False,
+    batch_size=None,
 ):
     """Given a stack of wavelength dependent amplitude and phase profiles, runs a reverse optimization to identify the nanostructures that implements the desired profile across wavelength by minimizing the mean absolute errors of complex fields.
 
@@ -28,6 +29,7 @@ def reverse_lookup_optimize(
         lr (float, optional): Optimization learning rate. Defaults to 1e-1.
         err_thresh (float, optional): Early termination threshold. Defaults to 0.1.
         max_iter (int, optional): Maximum number of steps. Defaults to 1000.
+        batch_size (int, optional): Number of cells to evaluate at once via model batching. Defaults to None (all cells in a single pass).
 
     Returns:
         list: Returns normalized and unnormalized metasurface design parameters of shape [B, H, W, D] where D is the number of shape parameters. Last item in list is the MAE loss for each step.
@@ -48,9 +50,11 @@ def reverse_lookup_optimize(
     pg = model.dim_out // 3
     assert pg == P, f"Polarization dimension of amp, phase (dim1) expected to be {pg}."
 
-    # z = np.random.rand(B, H, W, model.dim_in - 1)
-    z = np.zeros((B, H, W, model.dim_in - 1))
+    shape_dim = model.dim_in - 1
+    # z = np.random.rand(B, H, W, shape_dim)
+    z = np.zeros((B, H, W, shape_dim))
     z = torch.tensor(z, device=device, dtype=torch.float32, requires_grad=True)
+
     wavelength = (
         torch.tensor(wavelength_set_m)
         if not torch.is_tensor(wavelength_set_m)
@@ -58,11 +62,7 @@ def reverse_lookup_optimize(
     )
     wavelength = wavelength.to(dtype=torch.float32, device=device)
     wavelength = model.normalize_wavelength(wavelength)
-
-    # optimize
-    optimizer = optim.AdamW([z], lr=lr)
     torch_zero = torch.tensor(0.0, dtype=z.dtype, device=device)
-
     amp = (
         torch.tensor(amp, dtype=torch.float32, device=device)
         if not torch.is_tensor(amp)
@@ -77,9 +77,11 @@ def reverse_lookup_optimize(
         torch.complex(torch_zero, phase)
     )
 
+    # Optimize
     err = 1e3
     steps = 0
     err_list = []
+    optimizer = optim.AdamW([z], lr=lr)
     pbar = tqdm(total=max_iter, desc="Optimization Progress")
     while err > err_thresh:
         if steps >= max_iter:
@@ -88,7 +90,7 @@ def reverse_lookup_optimize(
 
         optimizer.zero_grad()
         pred_amp, pred_phase = model(
-            latent_to_param(z), wavelength, pre_normalized=True
+            latent_to_param(z), wavelength, pre_normalized=True, batch_size=batch_size
         )
 
         if opt_phase_only:
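
Not part of the patch: a minimal, self-contained sketch of the chunked evaluation that the new batch_size argument enables in forward(). Slicing the flattened (b h w) cells into batches of batch_size and concatenating the outputs reproduces the single-pass result while capping peak memory. The torch.nn.Linear stand-in model, tensor sizes, wavelength value, and tolerance below are illustrative assumptions, not DFlat components.

# Sketch only: `model` stands in for the trained cell model, which maps
# [N, num_shape_params + 1] -> [N, 3 * pol]; DFlat's real network is not loaded here.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3 + 1, 3)       # placeholder for self.model
x = torch.rand(64 * 64, 3)              # flattened (b h w) cells, 3 shape parameters
li = torch.tensor([0.45])               # one normalized wavelength
li_repeated = li.repeat(x.shape[0], 1)  # repeated once, as in the patched loop

batch_size = 1000
chunks = [
    model(torch.cat((x[start:start + batch_size], li_repeated[start:start + batch_size]), dim=1))
    for start in range(0, x.shape[0], batch_size)
]
out_batched = torch.cat(chunks, dim=0)

out_full = model(torch.cat((x, li_repeated), dim=1))  # previous single-pass behavior
assert torch.allclose(out_batched, out_full, atol=1e-6)

In reverse_lookup_optimize the same cap applies at every optimization step by forwarding the keyword (for example, a call ending in batch_size=4096); with the default batch_size=None both functions keep the previous all-cells-at-once behavior.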