diff --git a/odak/learn/wave/optimizers.py b/odak/learn/wave/optimizers.py index 416193b6..e2dbae5a 100644 --- a/odak/learn/wave/optimizers.py +++ b/odak/learn/wave/optimizers.py @@ -210,45 +210,6 @@ def evaluate(self, input_image, target_image, plane_id = 0): return loss - - def reconstruct(self, hologram_phases): - """ - Internal function to reconstruct a given hologram. - - - Parameters - ---------- - hologram_phases : torch.tensor - A monochrome hologram phase [mxn]. - - Returns - ------- - reconstruction_intensities : torch.tensor - Reconstructed frames. - reconstruction_intensity : torch.tensor - Reconstructed image. - peak_intensity : float - Peak intensity in the reconstructed image. - """ - torch.no_grad() - reconstruction_intensities = torch.zeros( - self.number_of_frames, - self.number_of_depth_layers, - self.number_of_channels, - self.resolution[0] * self.scale_factor, - self.resolution[1] * self.scale_factor, - device = self.device - ) - for frame_id in range(self.number_of_frames): - for depth_id in range(self.number_of_depth_layers): - for channel_id in range(self.number_of_channels): - laser_power = self.propagator_get_laser_powers()[frame_id][channel_id] - hologram = generate_complex_field(laser_power * self.amplitude, hologram_phases[frame_id] * self.phase_scale[channel_id]) - reconstruction_field = self.propagator(hologram, channel_id, depth_id) - reconstruction_intensities[frame_id, depth_id, channel_id] = calculate_amplitude(reconstruction_field) ** 2 - return reconstruction_intensities - - def double_phase_constrain(self, phase, phase_offset): """ Internal function to constrain a given phase similarly to double phase encoding. @@ -366,10 +327,11 @@ def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]): self.amplitude[1::self.scale_factor, 1::self.scale_factor] = 0. else: phase_scaled = phase - scaled_phase = phase_scaled * self.phase_scale[channel_id] laser_power = laser_powers[frame_id][channel_id] - amplitude = laser_power * self.amplitude - hologram = generate_complex_field(amplitude, scaled_phase) + hologram = generate_complex_field( + laser_power * self.amplitude, + phase_scaled * self.phase_scale[channel_id] + ) reconstruction_field = self.propagator(hologram, channel_id, depth_id) intensity = calculate_amplitude(reconstruction_field) ** 2 reconstruction_intensities[frame_id, channel_id] += intensity @@ -415,12 +377,10 @@ def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]): t.set_description(description) del total_loss del loss_image - del scaled_phase del reconstruction_field del reconstruction_intensities del intensity del phase - del amplitude del hologram logging.warning(description) return hologram_phases.detach() diff --git a/odak/learn/wave/propagators.py b/odak/learn/wave/propagators.py index 09ff442c..adedaa9c 100644 --- a/odak/learn/wave/propagators.py +++ b/odak/learn/wave/propagators.py @@ -351,5 +351,5 @@ def reconstruct(self, hologram_phases, amplitude = None, no_grad = True): frame_id, depth_id, channel_id - ] = calculate_amplitude(reconstruction_field) ** 2 + ] = calculate_amplitude(reconstruction_field).detach().clone() ** 2 return reconstruction_intensities